"git@developer.sourcefind.cn:change/sglang.git" did not exist on "72b6ea88b4354ad7551aab1594db0c967065c11d"
__init__.py 3.96 KB
Newer Older
1
2
import ctypes
import os
import platform
import shutil
from pathlib import Path
import torch

# copy & modify from torch/utils/cpp_extension.py
def _find_cuda_home():
    """Find the CUDA install path."""
    # Guess #1
    cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH")
    if cuda_home is None:
        # Guess #2
        nvcc_path = shutil.which("nvcc")
        if nvcc_path is not None:
            cuda_home = os.path.dirname(os.path.dirname(nvcc_path))
        else:
            # Guess #3
            cuda_home = "/usr/local/cuda"
    return cuda_home


# Pre-load libcudart.so.12 into the process with RTLD_GLOBAL so that the
# compiled extension imported below can resolve CUDA runtime symbols even
# when the CUDA libraries are not on the default dynamic-loader search path.
# Skipped on ROCm builds (torch.version.hip is set there).
if torch.version.hip is None:
    cuda_home = Path(_find_cuda_home())

    # Common layouts first: <cuda_home>/lib or <cuda_home>/lib64.
    if (cuda_home / "lib").is_dir():
        cuda_path = cuda_home / "lib"
    elif (cuda_home / "lib64").is_dir():
        cuda_path = cuda_home / "lib64"
    else:
        # Search for 'libcudart.so.12' in subdirectories
        for path in cuda_home.rglob("libcudart.so.12"):
            cuda_path = path.parent
            break
        else:
            # for/else: rglob yielded nothing, so no CUDA runtime exists
            # anywhere under cuda_home.
            raise RuntimeError("Could not find CUDA lib directory.")

    # NOTE(review): despite the name, this is the resolved path of the
    # runtime *library*, not an include directory.
    cuda_include = (cuda_path / "libcudart.so.12").resolve()
    if cuda_include.exists():
        ctypes.CDLL(str(cuda_include), mode=ctypes.RTLD_GLOBAL)
from sgl_kernel import common_ops
from sgl_kernel.allreduce import *
from sgl_kernel.attention import (
    cutlass_mla_decode,
    cutlass_mla_get_workspace_size,
    lightning_attention_decode,
    merge_state,
    merge_state_v2,
)
from sgl_kernel.cutlass_moe import cutlass_w4a8_moe_mm, get_cutlass_w4a8_moe_mm_data
from sgl_kernel.elementwise import (
    FusedSetKVBufferArg,
    apply_rope_with_cos_sin_cache_inplace,
    concat_mla_absorb_q,
    concat_mla_k,
    copy_to_gpu_no_ce,
    downcast_fp8,
    fused_add_rmsnorm,
    gelu_and_mul,
    gelu_tanh_and_mul,
    gemma_fused_add_rmsnorm,
    gemma_rmsnorm,
    rmsnorm,
    silu_and_mul,
)
from sgl_kernel.fused_moe import fused_marlin_moe
from sgl_kernel.gemm import (
    awq_dequantize,
    bmm_fp8,
    cutlass_scaled_fp4_mm,
    dsv3_fused_a_gemm,
    dsv3_router_gemm,
    fp8_blockwise_scaled_mm,
    fp8_scaled_mm,
    gptq_gemm,
    gptq_marlin_gemm,
    gptq_shuffle,
    int8_scaled_mm,
    qserve_w4a8_per_chn_gemm,
    qserve_w4a8_per_group_gemm,
    scaled_fp4_experts_quant,
    scaled_fp4_grouped_quant,
    scaled_fp4_quant,
    sgl_per_tensor_quant_fp8,
    sgl_per_token_group_quant_fp8,
    sgl_per_token_group_quant_int8,
    sgl_per_token_quant_fp8,
    shuffle_rows,
    silu_and_mul_scaled_fp4_grouped_quant,
)
from sgl_kernel.grammar import apply_token_bitmask_inplace_cuda
from sgl_kernel.kvcacheio import (
    transfer_kv_all_layer,
    transfer_kv_all_layer_mla,
    transfer_kv_per_layer,
    transfer_kv_per_layer_mla,
)
from sgl_kernel.mamba import causal_conv1d_fwd, causal_conv1d_update
from sgl_kernel.marlin import (
    awq_marlin_moe_repack,
    awq_marlin_repack,
    gptq_marlin_repack,
)
from sgl_kernel.memory import set_kv_buffer_kernel
from sgl_kernel.moe import (
    apply_shuffle_mul_sum,
    cutlass_fp4_group_mm,
    fp8_blockwise_scaled_grouped_mm,
    moe_align_block_size,
    moe_fused_gate,
    prepare_moe_input,
    topk_softmax,
)
from sgl_kernel.sampling import (
    min_p_sampling_from_probs,
    top_k_mask_logits,
    top_k_renorm_prob,
    top_k_top_p_sampling_from_logits,
    top_k_top_p_sampling_from_probs,
    top_p_renorm_prob,
    top_p_sampling_from_probs,
)
from sgl_kernel.speculative import (
    build_tree_kernel_efficient,
    segment_packbits,
    tree_speculative_sampling_target_only,
    verify_tree_greedy,
)
from sgl_kernel.top_k import fast_topk
from sgl_kernel.version import __version__
# gelu_quick is only exported on ROCm (HIP) builds; CUDA builds never
# import it — presumably it has no CUDA implementation (confirm upstream).
if torch.version.hip is not None:
    from sgl_kernel.elementwise import gelu_quick


def create_greenctx_stream_by_value(*args, **kwargs):
    """Proxy for ``sgl_kernel.spatial.create_greenctx_stream_by_value``.

    The spatial module is imported lazily, at call time, so that merely
    importing ``sgl_kernel`` does not pull it in.
    """
    from sgl_kernel.spatial import create_greenctx_stream_by_value as spatial_fn

    return spatial_fn(*args, **kwargs)


def get_sm_available(*args, **kwargs):
    """Proxy for ``sgl_kernel.spatial.get_sm_available``.

    Defers the import of the spatial module until the first call so the
    top-level package import stays lightweight.
    """
    from sgl_kernel.spatial import get_sm_available as spatial_fn

    return spatial_fn(*args, **kwargs)