Unverified Commit 8abf74e3 authored by Lianmin Zheng, committed by GitHub

Rename files in sgl kernel to avoid nested folder structure (#4213)


Co-authored-by: zhyncs <me@zhyncs.com>
parent ee132a45
 from typing import Optional, Tuple, Union

-import sgl_kernel.ops._kernels
 import torch
-from sgl_kernel.ops.utils import _to_tensor_scalar_tuple, get_cuda_stream
+from sgl_kernel.utils import _to_tensor_scalar_tuple, get_cuda_stream


 def _top_k_renorm_probs_internal(
@@ -13,7 +12,7 @@ def _top_k_renorm_probs_internal(
     probs = probs.float()
     maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
     renorm_probs = torch.empty_like(probs)
-    torch.ops.sgl_kernels.top_k_renorm_probs_wrapper(
+    torch.ops.sgl_kernel.top_k_renorm_probs_wrapper(
         probs,
         renorm_probs,
         maybe_top_k_arr,
@@ -41,7 +40,7 @@ def _top_p_renorm_probs_internal(
     probs = probs.float()
     maybe_top_p_arr = maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
     renorm_probs = torch.empty_like(probs)
-    torch.ops.sgl_kernels.top_p_renorm_probs(
+    torch.ops.sgl_kernel.top_p_renorm_probs(
         probs,
         renorm_probs,
         maybe_top_p_arr,
@@ -76,7 +75,7 @@ def _top_p_sampling_from_probs_internal(
     )
     samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
     success = torch.empty(probs.size(0), dtype=torch.bool, device=device)
-    torch.ops.sgl_kernels.top_p_sampling_from_probs(
+    torch.ops.sgl_kernel.top_p_sampling_from_probs(
         probs,
         uniform_samples,
         samples,
@@ -122,7 +121,7 @@ def _top_k_top_p_sampling_from_probs_internal(
     )
     samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
     success = torch.empty(probs.size(0), dtype=torch.bool, device=device)
-    torch.ops.sgl_kernels.top_k_top_p_sampling_from_probs(
+    torch.ops.sgl_kernel.top_k_top_p_sampling_from_probs(
         probs,
         uniform_samples,
         samples,
@@ -180,7 +179,7 @@ def _min_p_sampling_from_probs_internal(
         maybe_min_p_arr.float() if maybe_min_p_arr is not None else None
     )
     samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
-    torch.ops.sgl_kernels.min_p_sampling_from_probs(
+    torch.ops.sgl_kernel.min_p_sampling_from_probs(
        probs,
        uniform_samples,
        samples,
......
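For readers who don't know these wrappers, here is a pure-PyTorch sketch (ours, not part of this commit) of what `top_k_renorm_probs` is expected to compute: keep the k largest probabilities per row and renormalize. The fused CUDA wrapper above additionally accepts a per-row k via `maybe_top_k_arr`; the reference function name below is hypothetical.

```python
import torch


def top_k_renorm_probs_reference(probs: torch.Tensor, top_k: int) -> torch.Tensor:
    # Keep only the top_k largest probabilities per row, zero the rest,
    # then renormalize each row so it sums to 1 again.
    vals, idx = torch.topk(probs, top_k, dim=-1)
    kept = torch.zeros_like(probs).scatter_(-1, idx, vals)
    return kept / kept.sum(dim=-1, keepdim=True)


probs = torch.softmax(torch.randn(2, 8), dim=-1)
print(top_k_renorm_probs_reference(probs, top_k=3))
```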
-import sgl_kernel.ops._kernels
 import torch
-from sgl_kernel.ops.utils import get_cuda_stream
+from sgl_kernel.utils import get_cuda_stream


 def tree_speculative_sampling_target_only(
@@ -16,7 +15,7 @@ def tree_speculative_sampling_target_only(
     draft_probs: torch.Tensor,
     deterministic: bool = True,
 ) -> None:
-    torch.ops.sgl_kernels.tree_speculative_sampling_target_only(
+    torch.ops.sgl_kernel.tree_speculative_sampling_target_only(
         predicts,
         accept_index,
         accept_token_num,
@@ -45,7 +44,7 @@ def build_tree_kernel_efficient(
     depth: int,
     draft_token_num: int,
 ) -> None:
-    torch.ops.sgl_kernels.build_tree_kernel_efficient(
+    torch.ops.sgl_kernel.build_tree_kernel_efficient(
         parent_list,
         selected_index,
         verified_seq_len,
@@ -71,7 +70,7 @@ def build_tree_kernel(
     depth: int,
     draft_token_num: int,
 ) -> None:
-    torch.ops.sgl_kernels.build_tree_kernel(
+    torch.ops.sgl_kernel.build_tree_kernel(
         parent_list,
         selected_index,
         verified_seq_len,
......
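The same namespace rename applies to the speculative-decoding wrappers above. A hypothetical smoke check (ours, not from this PR), assuming that importing `sgl_kernel` loads the compiled `sgl_kernel.common_ops` extension and registers the ops:

```python
import torch
import sgl_kernel  # assumption: this import loads sgl_kernel.common_ops

# After this PR the ops live under torch.ops.sgl_kernel (singular),
# no longer under torch.ops.sgl_kernels.
for op in ("tree_speculative_sampling_target_only", "build_tree_kernel_efficient"):
    assert hasattr(torch.ops.sgl_kernel, op), f"missing op: {op}"
```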
@@ -48,16 +48,16 @@ def _get_version():
                 return line.split("=")[1].strip().strip('"')


-operator_namespace = "sgl_kernels"
+operator_namespace = "sgl_kernel"
 cutlass_default = root / "3rdparty" / "cutlass"
 cutlass = Path(os.environ.get("CUSTOM_CUTLASS_SRC_DIR", default=cutlass_default))
 flashinfer = root / "3rdparty" / "flashinfer"
 turbomind = root / "3rdparty" / "turbomind"
 include_dirs = [
+    root / "include",
+    root / "csrc",
     cutlass.resolve() / "include",
     cutlass.resolve() / "tools" / "util" / "include",
-    root / "src" / "sgl-kernel" / "include",
-    root / "src" / "sgl-kernel" / "csrc",
     flashinfer.resolve() / "include",
     flashinfer.resolve() / "include" / "gemm",
     flashinfer.resolve() / "csrc",
@@ -96,21 +96,21 @@ nvcc_flags_fp8 = [
 ]
 sources = [
-    "src/sgl-kernel/torch_extension.cc",
-    "src/sgl-kernel/csrc/activation/fused_add_rms_norm_kernel.cu",
-    "src/sgl-kernel/csrc/allreduce/trt_reduce_internal.cu",
-    "src/sgl-kernel/csrc/allreduce/trt_reduce_kernel.cu",
-    "src/sgl-kernel/csrc/attention/lightning_attention_decode_kernel.cu",
-    "src/sgl-kernel/csrc/gemm/cublas_grouped_gemm.cu",
-    "src/sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu",
-    "src/sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu",
-    "src/sgl-kernel/csrc/gemm/int8_gemm_kernel.cu",
-    "src/sgl-kernel/csrc/gemm/per_token_group_quant_fp8.cu",
-    "src/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu",
-    "src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu",
-    "src/sgl-kernel/csrc/moe/moe_align_kernel.cu",
-    "src/sgl-kernel/csrc/speculative/eagle_utils.cu",
-    "src/sgl-kernel/csrc/speculative/speculative_sampling.cu",
+    "csrc/allreduce/trt_reduce_internal.cu",
+    "csrc/allreduce/trt_reduce_kernel.cu",
+    "csrc/attention/lightning_attention_decode_kernel.cu",
+    "csrc/elementwise/fused_add_rms_norm_kernel.cu",
+    "csrc/gemm/cublas_grouped_gemm.cu",
+    "csrc/gemm/fp8_gemm_kernel.cu",
+    "csrc/gemm/fp8_blockwise_gemm_kernel.cu",
+    "csrc/gemm/int8_gemm_kernel.cu",
+    "csrc/gemm/per_token_group_quant_fp8.cu",
+    "csrc/gemm/per_token_quant_fp8.cu",
+    "csrc/gemm/per_tensor_quant_fp8.cu",
+    "csrc/moe/moe_align_kernel.cu",
+    "csrc/speculative/eagle_utils.cu",
+    "csrc/speculative/speculative_sampling.cu",
+    "csrc/torch_extension.cc",
     "3rdparty/flashinfer/csrc/activation.cu",
     "3rdparty/flashinfer/csrc/bmm_fp8.cu",
     "3rdparty/flashinfer/csrc/norm.cu",
@@ -158,7 +158,7 @@ extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib", "-L/usr/lib/x86_64-linu
 ext_modules = [
     CUDAExtension(
-        name="sgl_kernel.ops._kernels",
+        name="sgl_kernel.common_ops",
         sources=sources,
         include_dirs=include_dirs,
         extra_compile_args={
@@ -174,8 +174,8 @@ ext_modules = [
 setup(
     name="sgl-kernel",
     version=_get_version(),
-    packages=find_packages(),
-    package_dir={"": "src"},
+    packages=find_packages(where="python"),
+    package_dir={"": "python"},
     ext_modules=ext_modules,
     cmdclass={"build_ext": BuildExtension.with_options(use_ninja=True)},
     options={"bdist_wheel": {"py_limited_api": "cp39"}},
......
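A minimal sketch (ours, not part of this commit) of what the packaging change above means: the package root moves from `src/` to `python/`, so both `package_dir` and `find_packages` must point at the new directory.

```python
# Run from the sgl-kernel repo root; a sketch, not the project's build script.
from setuptools import find_packages

# package_dir={"": "python"} maps the root package namespace to python/;
# find_packages has to be told to search that same directory:
print(find_packages(where="python"))  # expected to include "sgl_kernel"
```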
@@ -13,12 +13,9 @@
 # limitations under the License.
 # ==============================================================================
-import multiprocessing
-import os
-import sys
 from pathlib import Path

 import torch
 from setuptools import find_packages, setup
 from torch.utils.cpp_extension import BuildExtension, CUDAExtension
@@ -35,16 +32,16 @@ def _get_version():
                 return line.split("=")[1].strip().strip('"')


-operator_namespace = "sgl_kernels"
+operator_namespace = "sgl_kernel"
 include_dirs = [
-    root / "src" / "sgl-kernel" / "include",
-    root / "src" / "sgl-kernel" / "csrc",
+    root / "include",
+    root / "csrc",
 ]
 sources = [
-    "src/sgl-kernel/torch_extension_rocm.cc",
-    "src/sgl-kernel/csrc/allreduce/custom_all_reduce.hip",
-    "src/sgl-kernel/csrc/moe/moe_align_kernel.cu",
+    "csrc/allreduce/custom_all_reduce.hip",
+    "csrc/moe/moe_align_kernel.cu",
+    "csrc/torch_extension_rocm.cc",
 ]
 cxx_flags = ["-O3"]
@@ -64,26 +61,27 @@ hipcc_flags = [
     "-DENABLE_FP8",
 ]

+ext_modules = [
+    CUDAExtension(
+        name="sgl_kernel.common_ops",
+        sources=sources,
+        include_dirs=include_dirs,
+        extra_compile_args={
+            "nvcc": hipcc_flags,
+            "cxx": cxx_flags,
+        },
+        libraries=libraries,
+        extra_link_args=extra_link_args,
+        py_limited_api=True,
+    ),
+]
+
 setup(
     name="sgl-kernel",
     version=_get_version(),
     packages=find_packages(),
-    package_dir={"": "src"},
-    ext_modules=[
-        CUDAExtension(
-            name="sgl_kernel.ops._kernels",
-            sources=sources,
-            include_dirs=include_dirs,
-            extra_compile_args={
-                "nvcc": hipcc_flags,
-                "cxx": cxx_flags,
-            },
-            libraries=libraries,
-            extra_link_args=extra_link_args,
-            py_limited_api=True,
-        ),
-    ],
+    package_dir={"": "python"},
+    ext_modules=ext_modules,
     cmdclass={"build_ext": BuildExtension.with_options(use_ninja=True)},
     options={"bdist_wheel": {"py_limited_api": "cp39"}},
     install_requires=["torch"],
 )
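As in the CUDA setup.py, `operator_namespace` is the prefix the C++ side registers its ops under, and it is exactly what appears after `torch.ops` in Python. A small sketch (ours) of that mapping:

```python
import torch

# The namespace object is created lazily; individual ops such as
# min_p_sampling_from_probs only resolve once the compiled extension loads.
ns = getattr(torch.ops, "sgl_kernel")
print(type(ns))  # torch._ops._OpNamespace
```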
......
@@ -7,7 +7,7 @@ import unittest
 from typing import Any, List, Optional

 import ray
-import sgl_kernel.ops.allreduce as custom_ops
+import sgl_kernel.allreduce as custom_ops
 import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
......
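For downstream code, the migration is a one-level flattening of the import path, as in the test file above. A before-and-after sketch (ours):

```python
# Old (nested layout):
#   import sgl_kernel.ops.allreduce as custom_ops
# New (flat layout introduced by this PR):
import sgl_kernel.allreduce as custom_ops
```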