import torch
from sgl_kernel.load_utils import _load_architecture_specific_ops, _preload_cuda_library

# Initialize the ops library based on the current GPU
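# (_load_architecture_specific_ops picks the compiled ops binary that matches
# the detected GPU architecture, presumably so a single wheel can ship kernels
# for several compute capabilities.)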
common_ops = _load_architecture_specific_ops()

# Preload the CUDA runtime library to avoid "libcudart.so.12 not found" errors
if torch.version.cuda is not None:
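    # torch.version.cuda is None on ROCm and CPU-only PyTorch builds, so the
    # preload only runs under CUDA.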
    _preload_cuda_library()

from sgl_kernel.allreduce import *
from sgl_kernel.attention import (
    cutlass_mla_decode,
    cutlass_mla_get_workspace_size,
    lightning_attention_decode,
    merge_state,
    merge_state_v2,
)
from sgl_kernel.cutlass_moe import cutlass_w4a8_moe_mm, get_cutlass_w4a8_moe_mm_data
from sgl_kernel.elementwise import (
    FusedSetKVBufferArg,
    apply_rope_with_cos_sin_cache_inplace,
    concat_mla_absorb_q,
    concat_mla_k,
    copy_to_gpu_no_ce,
    downcast_fp8,
    fused_add_rmsnorm,
    gelu_and_mul,
    gelu_tanh_and_mul,
    gemma_fused_add_rmsnorm,
    gemma_rmsnorm,
    rmsnorm,
    silu_and_mul,
)
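# Illustrative sketch (assumed layout, not stated in this file: silu_and_mul
# takes an input of shape [..., 2 * d] and returns [..., d] on the same device):
#
#   x = torch.randn(16, 2 * 4096, device="cuda", dtype=torch.float16)
#   y = silu_and_mul(x)  # silu(x[..., :d]) * x[..., d:]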
from sgl_kernel.expert_specialization import es_fp8_blockwise_scaled_grouped_mm
from sgl_kernel.fused_moe import fused_marlin_moe
from sgl_kernel.gemm import (
    awq_dequantize,
    bmm_fp8,
    cutlass_scaled_fp4_mm,
    dsv3_fused_a_gemm,
    dsv3_router_gemm,
    fp8_blockwise_scaled_mm,
    fp8_scaled_mm,
    gptq_gemm,
    gptq_marlin_gemm,
    gptq_shuffle,
    int8_scaled_mm,
    qserve_w4a8_per_chn_gemm,
    qserve_w4a8_per_group_gemm,
    scaled_fp4_experts_quant,
    scaled_fp4_grouped_quant,
    scaled_fp4_quant,
    sgl_per_tensor_quant_fp8,
    sgl_per_token_group_quant_8bit,
    sgl_per_token_group_quant_fp8,
    sgl_per_token_group_quant_int8,
    sgl_per_token_quant_fp8,
    shuffle_rows,
    silu_and_mul_scaled_fp4_grouped_quant,
)
from sgl_kernel.grammar import apply_token_bitmask_inplace_cuda
from sgl_kernel.hadamard import (
    hadamard_transform,
    hadamard_transform_12n,
    hadamard_transform_20n,
    hadamard_transform_28n,
    hadamard_transform_40n,
)
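# Illustrative sketch (assumed fast-hadamard-transform-style signature, with
# the last dimension a power of two for the base transform):
#
#   out = hadamard_transform(x)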
from sgl_kernel.kvcacheio import (
    transfer_kv_all_layer,
    transfer_kv_all_layer_mla,
    transfer_kv_per_layer,
    transfer_kv_per_layer_mla,
)
from sgl_kernel.mamba import causal_conv1d_fwd, causal_conv1d_update
from sgl_kernel.marlin import (
    awq_marlin_moe_repack,
    awq_marlin_repack,
    gptq_marlin_repack,
)
from sgl_kernel.memory import set_kv_buffer_kernel
from sgl_kernel.moe import (
    apply_shuffle_mul_sum,
    cutlass_fp4_group_mm,
    fp8_blockwise_scaled_grouped_mm,
    moe_align_block_size,
    moe_fused_gate,
    moe_sum,
    moe_sum_reduce,
    prepare_moe_input,
    topk_softmax,
)
from sgl_kernel.quantization import (
    ggml_dequantize,
    ggml_moe_a8,
    ggml_moe_a8_vec,
    ggml_moe_get_block_size,
    ggml_mul_mat_a8,
    ggml_mul_mat_vec_a8,
)
from sgl_kernel.sampling import (
    min_p_sampling_from_probs,
    top_k_mask_logits,
    top_k_renorm_prob,
    top_k_top_p_sampling_from_logits,
    top_k_top_p_sampling_from_probs,
    top_p_renorm_prob,
    top_p_sampling_from_probs,
)
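# Illustrative sketch (assumed flashinfer-style signature; see
# sgl_kernel.sampling for the exact arguments):
#
#   probs = torch.softmax(logits, dim=-1)
#   next_ids = top_p_sampling_from_probs(probs, top_p=0.9)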
from sgl_kernel.speculative import (
    build_tree_kernel_efficient,
    reconstruct_indices_from_tree_mask,
    segment_packbits,
    tree_speculative_sampling_target_only,
    verify_tree_greedy,
)
from sgl_kernel.top_k import (
    fast_topk,
    fast_topk_transform_fused,
    fast_topk_transform_ragged_fused,
    fast_topk_v2,
)
from sgl_kernel.version import __version__

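# gelu_quick is only exported on ROCm (HIP) builds.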
if torch.version.hip is not None:
    from sgl_kernel.elementwise import gelu_quick


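# sgl_kernel.spatial is imported lazily inside the wrappers below, presumably
# so that importing sgl_kernel does not hard-require the spatial extension.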
def create_greenctx_stream_by_value(*args, **kwargs):
    from sgl_kernel.spatial import create_greenctx_stream_by_value as _impl

    return _impl(*args, **kwargs)


def get_sm_available(*args, **kwargs):
    from sgl_kernel.spatial import get_sm_available as _impl

    return _impl(*args, **kwargs)