__init__.py 1.77 KB
Newer Older
1
2
import ctypes
import os
3
import platform
4

5
6
import torch

7
8
9
10
11
# Machine architecture string as reported by the platform module; it is
# interpolated into the CUDA toolkit target directory name below.
SYSTEM_ARCH = platform.machine()

# Best-effort preload of the CUDA 12 runtime from the default toolkit
# location. If the file is absent nothing is loaded and no error is raised.
# RTLD_GLOBAL asks the loader to expose the library's symbols to libraries
# loaded afterwards.
# NOTE(review): presumably this is so the compiled sgl_kernel extension
# modules imported further down can resolve libcudart — confirm.
cuda_path = f"/usr/local/cuda/targets/{SYSTEM_ARCH}-linux/lib/libcudart.so.12"
if os.path.exists(cuda_path):
    ctypes.CDLL(cuda_path, mode=ctypes.RTLD_GLOBAL)
12

13
14
from sgl_kernel import common_ops
from sgl_kernel.allreduce import *
15
16
17
18
from sgl_kernel.attention import (
    cutlass_mla_decode,
    cutlass_mla_get_workspace_size,
    lightning_attention_decode,
Yineng Zhang's avatar
Yineng Zhang committed
19
    merge_state,
20
    merge_state_v2,
21
)
22
from sgl_kernel.elementwise import (
23
24
25
26
27
28
29
30
31
    apply_rope_with_cos_sin_cache_inplace,
    fused_add_rmsnorm,
    gelu_and_mul,
    gelu_tanh_and_mul,
    gemma_fused_add_rmsnorm,
    gemma_rmsnorm,
    rmsnorm,
    silu_and_mul,
)
32
from sgl_kernel.gemm import (
33
    awq_dequantize,
34
    bmm_fp8,
Trevor Morris's avatar
Trevor Morris committed
35
    cutlass_scaled_fp4_mm,
36
37
38
    fp8_blockwise_scaled_mm,
    fp8_scaled_mm,
    int8_scaled_mm,
HandH1998's avatar
HandH1998 committed
39
40
    qserve_w4a8_per_chn_gemm,
    qserve_w4a8_per_group_gemm,
Trevor Morris's avatar
Trevor Morris committed
41
    scaled_fp4_quant,
42
    sgl_per_tensor_quant_fp8,
43
    sgl_per_token_group_quant_fp8,
44
    sgl_per_token_group_quant_int8,
45
    sgl_per_token_quant_fp8,
46
)
47
from sgl_kernel.grammar import apply_token_bitmask_inplace_cuda
48
from sgl_kernel.moe import (
49
    ep_moe_pre_reorder,
50
51
52
    fp8_blockwise_scaled_grouped_mm,
    moe_align_block_size,
    moe_fused_gate,
53
    prepare_moe_input,
54
55
    topk_softmax,
)
56
from sgl_kernel.sampling import (
57
58
59
60
61
62
    min_p_sampling_from_probs,
    top_k_renorm_prob,
    top_k_top_p_sampling_from_probs,
    top_p_renorm_prob,
    top_p_sampling_from_probs,
)
63
from sgl_kernel.speculative import (
64
    build_tree_kernel_efficient,
65
    segment_packbits,
66
    tree_speculative_sampling_target_only,
67
    verify_tree_greedy,
68
)
Lianmin Zheng's avatar
Lianmin Zheng committed
69
from sgl_kernel.version import __version__
70
71

# Placeholder attribute bound to None. Per the TODO it is kept only until the
# sglang Python code is updated, after which it can be deleted.
build_tree_kernel = (
    None  # TODO(ying): remove this after updating the sglang python code.
)