"mmdet3d/models/vscode:/vscode.git/clone" did not exist on "6dd5d329cfa3e9a595cca81522e8d0b2080fe780"
Unverified commit 8abf74e3, authored by Lianmin Zheng, committed by GitHub

Rename files in sgl kernel to avoid nested folder structure (#4213)


Co-authored-by: zhyncs <me@zhyncs.com>
parent ee132a45
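Summary of the change, as reflected in the hunks below: the commit flattens the old src/sgl-kernel nesting. C++/CUDA sources move to a top-level csrc/ directory, the pure-Python package moves under python/, the compiled extension is renamed from sgl_kernel.ops._kernels to sgl_kernel.common_ops, and the custom-operator namespace changes from sgl_kernels to sgl_kernel. The caller-visible effect, in brief:

# Before the rename:
#   from sgl_kernel.ops.utils import get_cuda_stream
#   torch.ops.sgl_kernels.top_p_renorm_probs(...)
# After the rename:
#   from sgl_kernel.utils import get_cuda_stream
#   torch.ops.sgl_kernel.top_p_renorm_probs(...)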
 from typing import Optional, Tuple, Union

-import sgl_kernel.ops._kernels
 import torch
-from sgl_kernel.ops.utils import _to_tensor_scalar_tuple, get_cuda_stream
+from sgl_kernel.utils import _to_tensor_scalar_tuple, get_cuda_stream


 def _top_k_renorm_probs_internal(
@@ -13,7 +12,7 @@ def _top_k_renorm_probs_internal(
     probs = probs.float()
     maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
     renorm_probs = torch.empty_like(probs)
-    torch.ops.sgl_kernels.top_k_renorm_probs_wrapper(
+    torch.ops.sgl_kernel.top_k_renorm_probs_wrapper(
         probs,
         renorm_probs,
         maybe_top_k_arr,
@@ -41,7 +40,7 @@ def _top_p_renorm_probs_internal(
     probs = probs.float()
     maybe_top_p_arr = maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
     renorm_probs = torch.empty_like(probs)
-    torch.ops.sgl_kernels.top_p_renorm_probs(
+    torch.ops.sgl_kernel.top_p_renorm_probs(
         probs,
         renorm_probs,
         maybe_top_p_arr,
@@ -76,7 +75,7 @@ def _top_p_sampling_from_probs_internal(
     )
     samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
     success = torch.empty(probs.size(0), dtype=torch.bool, device=device)
-    torch.ops.sgl_kernels.top_p_sampling_from_probs(
+    torch.ops.sgl_kernel.top_p_sampling_from_probs(
         probs,
         uniform_samples,
         samples,
@@ -122,7 +121,7 @@ def _top_k_top_p_sampling_from_probs_internal(
     )
     samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
     success = torch.empty(probs.size(0), dtype=torch.bool, device=device)
-    torch.ops.sgl_kernels.top_k_top_p_sampling_from_probs(
+    torch.ops.sgl_kernel.top_k_top_p_sampling_from_probs(
         probs,
         uniform_samples,
         samples,
@@ -180,7 +179,7 @@ def _min_p_sampling_from_probs_internal(
         maybe_min_p_arr.float() if maybe_min_p_arr is not None else None
     )
     samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
-    torch.ops.sgl_kernels.min_p_sampling_from_probs(
+    torch.ops.sgl_kernel.min_p_sampling_from_probs(
         probs,
         uniform_samples,
         samples,
...
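Note: the torch.ops rename has to land in the C++ registration and at every Python call site at once. torch.ops.<namespace>.<op> resolves lazily against PyTorch's operator registry, which is populated when the compiled extension module is imported. A minimal sketch of that resolution, assuming the renamed extension module sgl_kernel.common_ops from the setup.py hunk further below (the print is illustrative, not part of the library):

import torch
import sgl_kernel.common_ops  # loading the extension registers ops under "sgl_kernel"

# After this import, attribute access resolves against the registry; the old
# torch.ops.sgl_kernels.* lookups would now fail.
op = torch.ops.sgl_kernel.top_k_renorm_probs_wrapper
print(op)  # an OpOverloadPacket once registration has run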
-import sgl_kernel.ops._kernels
 import torch
-from sgl_kernel.ops.utils import get_cuda_stream
+from sgl_kernel.utils import get_cuda_stream


 def tree_speculative_sampling_target_only(
@@ -16,7 +15,7 @@ def tree_speculative_sampling_target_only(
     draft_probs: torch.Tensor,
     deterministic: bool = True,
 ) -> None:
-    torch.ops.sgl_kernels.tree_speculative_sampling_target_only(
+    torch.ops.sgl_kernel.tree_speculative_sampling_target_only(
         predicts,
         accept_index,
         accept_token_num,
@@ -45,7 +44,7 @@ def build_tree_kernel_efficient(
     depth: int,
     draft_token_num: int,
 ) -> None:
-    torch.ops.sgl_kernels.build_tree_kernel_efficient(
+    torch.ops.sgl_kernel.build_tree_kernel_efficient(
         parent_list,
         selected_index,
         verified_seq_len,
@@ -71,7 +70,7 @@ def build_tree_kernel(
     depth: int,
     draft_token_num: int,
 ) -> None:
-    torch.ops.sgl_kernels.build_tree_kernel(
+    torch.ops.sgl_kernel.build_tree_kernel(
         parent_list,
         selected_index,
         verified_seq_len,
...
@@ -48,16 +48,16 @@ def _get_version():
             return line.split("=")[1].strip().strip('"')


-operator_namespace = "sgl_kernels"
+operator_namespace = "sgl_kernel"
 cutlass_default = root / "3rdparty" / "cutlass"
 cutlass = Path(os.environ.get("CUSTOM_CUTLASS_SRC_DIR", default=cutlass_default))
 flashinfer = root / "3rdparty" / "flashinfer"
 turbomind = root / "3rdparty" / "turbomind"
 include_dirs = [
+    root / "include",
+    root / "csrc",
     cutlass.resolve() / "include",
     cutlass.resolve() / "tools" / "util" / "include",
-    root / "src" / "sgl-kernel" / "include",
-    root / "src" / "sgl-kernel" / "csrc",
     flashinfer.resolve() / "include",
     flashinfer.resolve() / "include" / "gemm",
     flashinfer.resolve() / "csrc",
@@ -96,21 +96,21 @@ nvcc_flags_fp8 = [
 ]

 sources = [
-    "src/sgl-kernel/torch_extension.cc",
-    "src/sgl-kernel/csrc/activation/fused_add_rms_norm_kernel.cu",
-    "src/sgl-kernel/csrc/allreduce/trt_reduce_internal.cu",
-    "src/sgl-kernel/csrc/allreduce/trt_reduce_kernel.cu",
-    "src/sgl-kernel/csrc/attention/lightning_attention_decode_kernel.cu",
-    "src/sgl-kernel/csrc/gemm/cublas_grouped_gemm.cu",
-    "src/sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu",
-    "src/sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu",
-    "src/sgl-kernel/csrc/gemm/int8_gemm_kernel.cu",
-    "src/sgl-kernel/csrc/gemm/per_token_group_quant_fp8.cu",
-    "src/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu",
-    "src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu",
-    "src/sgl-kernel/csrc/moe/moe_align_kernel.cu",
-    "src/sgl-kernel/csrc/speculative/eagle_utils.cu",
-    "src/sgl-kernel/csrc/speculative/speculative_sampling.cu",
+    "csrc/allreduce/trt_reduce_internal.cu",
+    "csrc/allreduce/trt_reduce_kernel.cu",
+    "csrc/attention/lightning_attention_decode_kernel.cu",
+    "csrc/elementwise/fused_add_rms_norm_kernel.cu",
+    "csrc/gemm/cublas_grouped_gemm.cu",
+    "csrc/gemm/fp8_gemm_kernel.cu",
+    "csrc/gemm/fp8_blockwise_gemm_kernel.cu",
+    "csrc/gemm/int8_gemm_kernel.cu",
+    "csrc/gemm/per_token_group_quant_fp8.cu",
+    "csrc/gemm/per_token_quant_fp8.cu",
+    "csrc/gemm/per_tensor_quant_fp8.cu",
+    "csrc/moe/moe_align_kernel.cu",
+    "csrc/speculative/eagle_utils.cu",
+    "csrc/speculative/speculative_sampling.cu",
+    "csrc/torch_extension.cc",
     "3rdparty/flashinfer/csrc/activation.cu",
     "3rdparty/flashinfer/csrc/bmm_fp8.cu",
     "3rdparty/flashinfer/csrc/norm.cu",
@@ -158,7 +158,7 @@ extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib", "-L/usr/lib/x86_64-linu
 ext_modules = [
     CUDAExtension(
-        name="sgl_kernel.ops._kernels",
+        name="sgl_kernel.common_ops",
         sources=sources,
         include_dirs=include_dirs,
         extra_compile_args={
@@ -174,8 +174,8 @@ ext_modules = [
 setup(
     name="sgl-kernel",
     version=_get_version(),
-    packages=find_packages(),
-    package_dir={"": "src"},
+    packages=find_packages(where="python"),
+    package_dir={"": "python"},
     ext_modules=ext_modules,
     cmdclass={"build_ext": BuildExtension.with_options(use_ninja=True)},
     options={"bdist_wheel": {"py_limited_api": "cp39"}},
...
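Note: the packaging edits are the standard setuptools src-layout idiom. package_dir={"": "python"} declares python/ as the root of the import namespace, so find_packages must be pointed at that same directory or the sgl_kernel package would not be discovered. A small self-contained sketch (directory names inferred from the diff, not verbatim repo contents):

# Implied layout after this commit:
#   python/sgl_kernel/...   <- pure-Python wrappers, imported as sgl_kernel
#   csrc/...                <- C++/CUDA sources listed in `sources` above
from setuptools import find_packages

# Run from the sgl-kernel project root this should report ["sgl_kernel", ...];
# run from anywhere else it simply returns [].
print(find_packages(where="python"))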
@@ -13,12 +13,9 @@
 # limitations under the License.
 # ==============================================================================

-import multiprocessing
-import os
 import sys
 from pathlib import Path

-import torch
 from setuptools import find_packages, setup
 from torch.utils.cpp_extension import BuildExtension, CUDAExtension

@@ -35,16 +32,16 @@ def _get_version():
             return line.split("=")[1].strip().strip('"')


-operator_namespace = "sgl_kernels"
+operator_namespace = "sgl_kernel"
 include_dirs = [
-    root / "src" / "sgl-kernel" / "include",
-    root / "src" / "sgl-kernel" / "csrc",
+    root / "include",
+    root / "csrc",
 ]

 sources = [
-    "src/sgl-kernel/torch_extension_rocm.cc",
-    "src/sgl-kernel/csrc/allreduce/custom_all_reduce.hip",
-    "src/sgl-kernel/csrc/moe/moe_align_kernel.cu",
+    "csrc/allreduce/custom_all_reduce.hip",
+    "csrc/moe/moe_align_kernel.cu",
+    "csrc/torch_extension_rocm.cc",
 ]

 cxx_flags = ["-O3"]
@@ -64,26 +61,27 @@ hipcc_flags = [
     "-DENABLE_FP8",
 ]

+ext_modules = [
+    CUDAExtension(
+        name="sgl_kernel.common_ops",
+        sources=sources,
+        include_dirs=include_dirs,
+        extra_compile_args={
+            "nvcc": hipcc_flags,
+            "cxx": cxx_flags,
+        },
+        libraries=libraries,
+        extra_link_args=extra_link_args,
+        py_limited_api=True,
+    ),
+]
+
 setup(
     name="sgl-kernel",
     version=_get_version(),
     packages=find_packages(),
-    package_dir={"": "src"},
-    ext_modules=[
-        CUDAExtension(
-            name="sgl_kernel.ops._kernels",
-            sources=sources,
-            include_dirs=include_dirs,
-            extra_compile_args={
-                "nvcc": hipcc_flags,
-                "cxx": cxx_flags,
-            },
-            libraries=libraries,
-            extra_link_args=extra_link_args,
-            py_limited_api=True,
-        ),
-    ],
+    package_dir={"": "python"},
+    ext_modules=ext_modules,
     cmdclass={"build_ext": BuildExtension.with_options(use_ninja=True)},
     options={"bdist_wheel": {"py_limited_api": "cp39"}},
+    install_requires=["torch"],
 )
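Note: the ROCm setup now mirrors the CUDA one, hoisting ext_modules out of setup() into a named list and declaring an explicit torch dependency. Both builds keep the stable-ABI configuration: py_limited_api=True on the extension plus the bdist_wheel option should yield a single cp39-abi3 wheel that serves CPython 3.9 and newer, rather than one wheel per interpreter version. A stripped-down sketch of just that configuration (every name other than the two py_limited_api settings is a placeholder):

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name="example-abi3-ext",                     # placeholder project name
    ext_modules=[
        CUDAExtension(
            name="example.common_ops",           # placeholder module path
            sources=["csrc/example.cu"],         # placeholder source file
            py_limited_api=True,                 # compile against the stable ABI
        ),
    ],
    cmdclass={"build_ext": BuildExtension},
    options={"bdist_wheel": {"py_limited_api": "cp39"}},  # tag the wheel cp39-abi3
)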
@@ -7,7 +7,7 @@ import unittest
 from typing import Any, List, Optional

 import ray
-import sgl_kernel.ops.allreduce as custom_ops
+import sgl_kernel.allreduce as custom_ops
 import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
...
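Note: downstream imports adapt one level up accordingly, e.g. the allreduce helpers used by this test (module path taken from the hunk above):

import sgl_kernel.allreduce as custom_ops  # new flattened path
# import sgl_kernel.ops.allreduce as custom_ops  # old nested path, removed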