"docs/git@developer.sourcefind.cn:change/sglang.git" did not exist on "fae4e5e99a93f8f5e7fa462833754c91ecbea1c2"
__init__.py 2.98 KB
Newer Older
1
2
import ctypes
import os
3
import platform
4

5
6
import torch

7
8
9
10
11
# Machine architecture string reported by the platform module; it selects
# the "targets/<arch>-linux" directory probed below.
SYSTEM_ARCH = platform.machine()

# Fixed location checked for the CUDA 12 runtime shared library.
cuda_path = f"/usr/local/cuda/targets/{SYSTEM_ARCH}-linux/lib/libcudart.so.12"

# Best-effort preload: the library is opened with RTLD_GLOBAL only when the
# file exists; a missing file is skipped without raising.
# NOTE(review): presumably this lets extension modules imported afterwards
# resolve CUDA runtime symbols -- confirm against the build setup.
if os.path.exists(cuda_path):
    ctypes.CDLL(name=cuda_path, mode=ctypes.RTLD_GLOBAL)
12

13
14
from sgl_kernel import common_ops
from sgl_kernel.allreduce import *
15
16
17
18
from sgl_kernel.attention import (
    cutlass_mla_decode,
    cutlass_mla_get_workspace_size,
    lightning_attention_decode,
Yineng Zhang's avatar
Yineng Zhang committed
19
    merge_state,
20
    merge_state_v2,
21
)
22
from sgl_kernel.cutlass_moe import cutlass_w4a8_moe_mm, get_cutlass_w4a8_moe_mm_data
23
from sgl_kernel.elementwise import (
24
    FusedSetKVBufferArg,
25
    apply_rope_with_cos_sin_cache_inplace,
26
    concat_mla_k,
27
    copy_to_gpu_no_ce,
28
    downcast_fp8,
29
30
31
32
33
34
35
36
    fused_add_rmsnorm,
    gelu_and_mul,
    gelu_tanh_and_mul,
    gemma_fused_add_rmsnorm,
    gemma_rmsnorm,
    rmsnorm,
    silu_and_mul,
)
Yi Zhang's avatar
Yi Zhang committed
37
from sgl_kernel.mamba import causal_conv1d_fwd, causal_conv1d_update
38
39
40
41

# gelu_quick is imported by this statement only when torch.version.hip is
# not None, i.e. when PyTorch reports a HIP version; otherwise this
# statement binds nothing.
# NOTE(review): presumably the kernel is only built for ROCm -- confirm.
if torch.version.hip is not None:
    from sgl_kernel.elementwise import gelu_quick

42
from sgl_kernel.fused_moe import fused_marlin_moe
43
from sgl_kernel.gemm import (
44
    awq_dequantize,
45
    bmm_fp8,
Trevor Morris's avatar
Trevor Morris committed
46
    cutlass_scaled_fp4_mm,
47
    dsv3_fused_a_gemm,
48
    dsv3_router_gemm,
49
50
    fp8_blockwise_scaled_mm,
    fp8_scaled_mm,
51
52
53
    gptq_gemm,
    gptq_marlin_gemm,
    gptq_shuffle,
54
    int8_scaled_mm,
HandH1998's avatar
HandH1998 committed
55
56
    qserve_w4a8_per_chn_gemm,
    qserve_w4a8_per_group_gemm,
57
    scaled_fp4_experts_quant,
58
    scaled_fp4_grouped_quant,
Trevor Morris's avatar
Trevor Morris committed
59
    scaled_fp4_quant,
60
    sgl_per_tensor_quant_fp8,
61
62
    sgl_per_token_group_quant_fp8,
    sgl_per_token_group_quant_int8,
63
    sgl_per_token_quant_fp8,
64
    shuffle_rows,
65
    silu_and_mul_scaled_fp4_grouped_quant,
66
)
67
from sgl_kernel.grammar import apply_token_bitmask_inplace_cuda
68
69
70
71
72
73
from sgl_kernel.kvcacheio import (
    transfer_kv_all_layer,
    transfer_kv_all_layer_mla,
    transfer_kv_per_layer,
    transfer_kv_per_layer_mla,
)
74
75
76
77
78
from sgl_kernel.marlin import (
    awq_marlin_moe_repack,
    awq_marlin_repack,
    gptq_marlin_repack,
)
79
from sgl_kernel.memory import set_kv_buffer_kernel
80
from sgl_kernel.moe import (
81
    apply_shuffle_mul_sum,
82
    cutlass_fp4_group_mm,
83
84
85
    fp8_blockwise_scaled_grouped_mm,
    moe_align_block_size,
    moe_fused_gate,
86
    prepare_moe_input,
87
88
    topk_softmax,
)
89
from sgl_kernel.sampling import (
90
    min_p_sampling_from_probs,
91
    top_k_mask_logits,
92
    top_k_renorm_prob,
93
    top_k_top_p_sampling_from_logits,
94
95
96
97
    top_k_top_p_sampling_from_probs,
    top_p_renorm_prob,
    top_p_sampling_from_probs,
)
98
99
100
101
102
103
104
105
from sgl_kernel.speculative import (
    build_tree_kernel_efficient,
    segment_packbits,
    tree_speculative_sampling_target_only,
    verify_tree_greedy,
)
from sgl_kernel.top_k import fast_topk
from sgl_kernel.version import __version__
106
107
108
109
110
111
112
113
114
115
116
117


def create_greenctx_stream_by_value(*args, **kwargs):
    """Forward the call to ``sgl_kernel.spatial.create_greenctx_stream_by_value``.

    The import of ``sgl_kernel.spatial`` happens inside this function, so it
    is performed when the wrapper is called rather than when this package is
    imported.

    Args:
        *args: Positional arguments, passed through unchanged.
        **kwargs: Keyword arguments, passed through unchanged.

    Returns:
        Whatever the underlying implementation returns.
    """
    # Deferred import: resolved on each call (cached by the import system
    # after the first successful import).
    from sgl_kernel.spatial import (
        create_greenctx_stream_by_value as _spatial_create_stream,
    )

    return _spatial_create_stream(*args, **kwargs)


def get_sm_available(*args, **kwargs):
    """Forward the call to ``sgl_kernel.spatial.get_sm_available``.

    The import of ``sgl_kernel.spatial`` happens inside this function, so it
    is performed when the wrapper is called rather than when this package is
    imported.

    Args:
        *args: Positional arguments, passed through unchanged.
        **kwargs: Keyword arguments, passed through unchanged.

    Returns:
        Whatever the underlying implementation returns.
    """
    # Deferred import: resolved on each call (cached by the import system
    # after the first successful import).
    from sgl_kernel.spatial import get_sm_available as _spatial_get_sm

    return _spatial_get_sm(*args, **kwargs)