"vscode:/vscode.git/clone" did not exist on "0b1a843d321fe08343c738dab4bf01e81520c16b"
__init__.py 3.97 KB
Newer Older
1
2
import ctypes
import os
import platform
import shutil
from pathlib import Path

import torch


# Copied and modified from torch/utils/cpp_extension.py.
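# Search order: the CUDA_HOME / CUDA_PATH environment variables, then the
# toolkit that provides `nvcc` on PATH, then the conventional /usr/local/cuda.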
def _find_cuda_home():
    """Find the CUDA install path."""
    # Guess #1
    cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH")
    if cuda_home is None:
        # Guess #2
        nvcc_path = shutil.which("nvcc")
        if nvcc_path is not None:
            cuda_home = os.path.dirname(os.path.dirname(nvcc_path))
        else:
            # Guess #3
            cuda_home = "/usr/local/cuda"
    return cuda_home


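# On CUDA builds, locate the toolkit's library directory so that the CUDA
# runtime (libcudart.so.12) can be preloaded before the compiled extension is
# imported further below.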
if torch.version.cuda is not None:
    cuda_home = Path(_find_cuda_home())

    if (cuda_home / "lib").is_dir():
        cuda_path = cuda_home / "lib"
    elif (cuda_home / "lib64").is_dir():
        cuda_path = cuda_home / "lib64"
    else:
        # Search for 'libcudart.so.12' in subdirectories
        for path in cuda_home.rglob("libcudart.so.12"):
            cuda_path = path.parent
            break
        else:
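            # The loop never hit `break`, i.e. no libcudart.so.12 was found
            # anywhere under cuda_home.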
            raise RuntimeError("Could not find CUDA lib directory.")

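    # Despite its name, `cuda_include` points at the CUDA runtime shared
    # library. Loading it with RTLD_GLOBAL makes its symbols visible to shared
    # objects loaded afterwards, such as the common_ops extension.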
    cuda_include = (cuda_path / "libcudart.so.12").resolve()
    if cuda_include.exists():
        ctypes.CDLL(str(cuda_include), mode=ctypes.RTLD_GLOBAL)

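# common_ops is the compiled extension module; importing it loads the native
# kernels backing the Python wrappers re-exported below.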
from sgl_kernel import common_ops
from sgl_kernel.allreduce import *
from sgl_kernel.attention import (
    cutlass_mla_decode,
    cutlass_mla_get_workspace_size,
    lightning_attention_decode,
    merge_state,
    merge_state_v2,
)
from sgl_kernel.cutlass_moe import cutlass_w4a8_moe_mm, get_cutlass_w4a8_moe_mm_data
from sgl_kernel.elementwise import (
    FusedSetKVBufferArg,
    apply_rope_with_cos_sin_cache_inplace,
    concat_mla_absorb_q,
    concat_mla_k,
    copy_to_gpu_no_ce,
    downcast_fp8,
    fused_add_rmsnorm,
    gelu_and_mul,
    gelu_tanh_and_mul,
    gemma_fused_add_rmsnorm,
    gemma_rmsnorm,
    rmsnorm,
    silu_and_mul,
)
from sgl_kernel.fused_moe import fused_marlin_moe
from sgl_kernel.gemm import (
    awq_dequantize,
    bmm_fp8,
    cutlass_scaled_fp4_mm,
    dsv3_fused_a_gemm,
    dsv3_router_gemm,
    fp8_blockwise_scaled_mm,
    fp8_scaled_mm,
    gptq_gemm,
    gptq_marlin_gemm,
    gptq_shuffle,
    int8_scaled_mm,
    qserve_w4a8_per_chn_gemm,
    qserve_w4a8_per_group_gemm,
    scaled_fp4_experts_quant,
    scaled_fp4_grouped_quant,
    scaled_fp4_quant,
    sgl_per_tensor_quant_fp8,
    sgl_per_token_group_quant_fp8,
    sgl_per_token_group_quant_int8,
    sgl_per_token_quant_fp8,
    shuffle_rows,
    silu_and_mul_scaled_fp4_grouped_quant,
)
from sgl_kernel.grammar import apply_token_bitmask_inplace_cuda
from sgl_kernel.kvcacheio import (
    transfer_kv_all_layer,
    transfer_kv_all_layer_mla,
    transfer_kv_per_layer,
    transfer_kv_per_layer_mla,
)
from sgl_kernel.mamba import causal_conv1d_fwd, causal_conv1d_update
from sgl_kernel.marlin import (
    awq_marlin_moe_repack,
    awq_marlin_repack,
    gptq_marlin_repack,
)
from sgl_kernel.memory import set_kv_buffer_kernel
from sgl_kernel.moe import (
    apply_shuffle_mul_sum,
    cutlass_fp4_group_mm,
    fp8_blockwise_scaled_grouped_mm,
    moe_align_block_size,
    moe_fused_gate,
    prepare_moe_input,
    topk_softmax,
)
from sgl_kernel.sampling import (
    min_p_sampling_from_probs,
    top_k_mask_logits,
    top_k_renorm_prob,
    top_k_top_p_sampling_from_logits,
    top_k_top_p_sampling_from_probs,
    top_p_renorm_prob,
    top_p_sampling_from_probs,
)
from sgl_kernel.speculative import (
    build_tree_kernel_efficient,
    segment_packbits,
    tree_speculative_sampling_target_only,
    verify_tree_greedy,
)
from sgl_kernel.top_k import fast_topk
from sgl_kernel.version import __version__

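# gelu_quick is re-exported only for ROCm (HIP) builds.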
if torch.version.hip is not None:
    from sgl_kernel.elementwise import gelu_quick


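# sgl_kernel.spatial is imported lazily inside these wrappers, presumably so
# that importing sgl_kernel itself does not require the spatial ops to be
# importable on every platform.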
def create_greenctx_stream_by_value(*args, **kwargs):
    from sgl_kernel.spatial import create_greenctx_stream_by_value as _impl

    return _impl(*args, **kwargs)


def get_sm_available(*args, **kwargs):
    from sgl_kernel.spatial import get_sm_available as _impl

    return _impl(*args, **kwargs)