Unverified commit fd71b11b, authored by Lianmin Zheng and committed by GitHub

move is_sm90_supported/is_sm100_supported to python/sglang/srt/utils.py (#9679)

parent ae7428a8
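In short, the two CUDA-capability helpers now live in the shared `sglang.srt.utils` module rather than `sglang.srt.layers.utils`, so every caller only changes its import path. A minimal before/after sketch for a typical caller (module paths taken from the diff below; the snippet itself is illustrative, not part of the commit):

```python
# Old location (removed by this commit):
# from sglang.srt.layers.utils import is_sm90_supported, is_sm100_supported

# New location:
from sglang.srt.utils import is_sm90_supported, is_sm100_supported
```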
@@ -26,11 +26,14 @@ from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.radix_attention import AttentionType
-from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
-from sglang.srt.utils import is_flashinfer_available, next_power_of_2
+from sglang.srt.utils import (
+    is_flashinfer_available,
+    is_sm100_supported,
+    next_power_of_2,
+)
 
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
......
@@ -28,11 +28,14 @@ from sglang.srt.layers.attention.flashinfer_backend import (
     create_flashinfer_kv_indices_triton,
 )
 from sglang.srt.layers.dp_attention import get_attention_tp_size
-from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
-from sglang.srt.utils import is_flashinfer_available, next_power_of_2
+from sglang.srt.utils import (
+    is_flashinfer_available,
+    is_sm100_supported,
+    next_power_of_2,
+)
 
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
......
@@ -40,10 +40,9 @@ from sglang.srt.layers.moe import (
     get_moe_a2a_backend,
     should_use_flashinfer_cutlass_moe_fp4_allgather,
 )
-from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import is_cuda, is_flashinfer_available
+from sglang.srt.utils import is_cuda, is_flashinfer_available, is_sm100_supported
 
 _is_flashinfer_available = is_flashinfer_available()
 _is_sm100_supported = is_cuda() and is_sm100_supported()
......
"""CUTLASS based Fused MoE kernels."""
import functools
import json
import logging
import os
from typing import Any, Callable, Dict, List, Optional, Tuple
import torch
from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams
from sglang.srt.layers.utils import is_sm90_supported, is_sm100_supported
from sglang.srt.utils import is_cuda
_is_cuda = is_cuda()
if _is_cuda:
import sgl_kernel
from sgl_kernel import (
apply_shuffle_mul_sum,
cutlass_fp4_group_mm,
......
@@ -64,7 +64,6 @@ from sglang.srt.layers.quantization.utils import (
     per_tensor_dequantize,
     requantize_with_max_scale,
 )
-from sglang.srt.layers.utils import is_sm90_supported, is_sm100_supported
 from sglang.srt.utils import (
     cpu_has_amx_support,
     get_bool_env_var,
@@ -72,6 +71,8 @@ from sglang.srt.utils import (
     is_cuda,
     is_hip,
     is_npu,
+    is_sm90_supported,
+    is_sm100_supported,
     log_info_on_rank0,
     next_power_of_2,
     print_warning_once,
......
@@ -5,7 +5,7 @@ import torch
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
 from sglang.srt.layers.quantization.mxfp4_tensor import MXFP4QuantizeUtil
-from sglang.srt.layers.utils import is_sm100_supported
+from sglang.srt.utils import is_sm100_supported
 
 try:
     from vllm import _custom_ops as ops
......
@@ -29,14 +29,13 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.utils import is_layer_skipped
-from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import (
     direct_register_custom_op,
     get_bool_env_var,
     is_cuda,
     is_flashinfer_available,
     is_hip,
+    is_sm100_supported,
     is_triton_kernels_available,
     log_info_on_rank0,
     mxfp_supported,
......
@@ -34,17 +34,3 @@ class PPMissingLayer(torch.nn.Identity):
         """
         input = args[0] if args else next(iter(kwargs.values()))
         return (input,) if self.return_tuple else input
-
-
-@lru_cache(maxsize=1)
-def is_sm100_supported(device=None) -> bool:
-    return (torch.cuda.get_device_capability(device)[0] == 10) and (
-        torch.version.cuda >= "12.8"
-    )
-
-
-@lru_cache(maxsize=1)
-def is_sm90_supported(device=None) -> bool:
-    return (torch.cuda.get_device_capability(device)[0] == 9) and (
-        torch.version.cuda >= "12.3"
-    )
@@ -66,7 +66,6 @@ from sglang.srt.layers.quantization import (
 )
 from sglang.srt.layers.sampler import Sampler
 from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
-from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.lora.lora_manager import LoRAManager
 from sglang.srt.lora.lora_registry import LoRARef
 from sglang.srt.managers.schedule_batch import (
@@ -121,6 +120,7 @@ from sglang.srt.utils import (
     is_hopper_with_cuda_12_3,
     is_no_spec_infer_or_topk_one,
     is_npu,
+    is_sm100_supported,
     monkey_patch_p2p_access_check,
     monkey_patch_vllm_gguf_config,
     set_cuda_arch,
......
@@ -87,8 +87,8 @@ from sglang.srt.layers.quantization.int8_utils import (
     block_dequant as int8_block_dequant,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.rotary_embedding import get_rope, get_rope_wrapper
-from sglang.srt.layers.utils import PPMissingLayer, get_layer_id, is_sm100_supported
+from sglang.srt.layers.rotary_embedding import get_rope_wrapper
+from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
@@ -114,6 +114,7 @@ from sglang.srt.utils import (
     is_flashinfer_available,
     is_hip,
     is_non_idle_and_non_empty,
+    is_sm100_supported,
     log_info_on_rank0,
     make_layers,
     use_intel_amx_backend,
......
@@ -58,7 +58,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8_utils import dequant_mxfp4
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
-from sglang.srt.layers.utils import PPMissingLayer, get_layer_id, is_sm100_supported
+from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
@@ -71,6 +71,7 @@ from sglang.srt.utils import (
     add_prefix,
     is_cuda,
     is_flashinfer_available,
+    is_sm100_supported,
     make_layers,
 )
 
......
@@ -25,7 +25,6 @@ from typing import List, Literal, Optional, Union
 
 from sglang.srt.function_call.function_call_parser import FunctionCallParser
 from sglang.srt.hf_transformers_utils import check_gguf_file, get_config
-from sglang.srt.layers.utils import is_sm90_supported, is_sm100_supported
 from sglang.srt.lora.lora_registry import LoRARef
 from sglang.srt.reasoning_parser import ReasoningParser
 from sglang.srt.utils import (
@@ -39,6 +38,8 @@ from sglang.srt.utils import (
     is_hip,
     is_port_available,
     is_remote_url,
+    is_sm90_supported,
+    is_sm100_supported,
     is_triton_kernels_available,
     is_valid_ipv6_address,
     nullable_str,
......
@@ -172,6 +172,20 @@ def is_blackwell():
     return torch.cuda.get_device_capability()[0] == 10
 
 
+@lru_cache(maxsize=1)
+def is_sm100_supported(device=None) -> bool:
+    return (torch.cuda.get_device_capability(device)[0] == 10) and (
+        torch.version.cuda >= "12.8"
+    )
+
+
+@lru_cache(maxsize=1)
+def is_sm90_supported(device=None) -> bool:
+    return (torch.cuda.get_device_capability(device)[0] == 9) and (
+        torch.version.cuda >= "12.3"
+    )
+
+
 _warned_bool_env_var_keys = set()
......
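For context, these helpers gate architecture-specific code paths (SM90/Hopper requires CUDA >= 12.3, SM100/Blackwell requires CUDA >= 12.8, per the definitions above). A short usage sketch that mirrors the module-level `_is_sm100_supported = is_cuda() and is_sm100_supported()` pattern from the MoE hunk above; `pick_gemm_backend` and its return labels are hypothetical, not sglang APIs:

```python
from sglang.srt.utils import is_cuda, is_sm90_supported, is_sm100_supported

# Check is_cuda() first so the capability probes never run on non-CUDA builds,
# matching the module-level guards used in the diff above.
_is_cuda = is_cuda()
_is_sm90_supported = _is_cuda and is_sm90_supported()
_is_sm100_supported = _is_cuda and is_sm100_supported()


def pick_gemm_backend() -> str:
    """Hypothetical example: label the detected GPU generation."""
    if _is_sm100_supported:
        return "sm100"  # Blackwell-class GPU with CUDA >= 12.8
    if _is_sm90_supported:
        return "sm90"  # Hopper-class GPU with CUDA >= 12.3
    return "fallback"
```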