Unverified Commit fd71b11b authored by Lianmin Zheng, committed by GitHub

move is_sm90_supported/is_sm100_supported to python/sglang/srt/utils.py (#9679)

parent ae7428a8
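
This commit moves the is_sm90_supported / is_sm100_supported capability helpers out of sglang.srt.layers.utils and into sglang.srt.utils, then updates every import site; some import lists are trimmed along the way. As a quick orientation, here is a minimal sketch of the call-site pattern after the change, mirroring the guarded usage visible in the MoE hunk below (the module-level _is_* aliases are just the local names used there):

    # Sketch of the post-change import pattern; the helpers now live in
    # sglang.srt.utils next to the other device checks.
    from sglang.srt.utils import is_cuda, is_sm90_supported, is_sm100_supported

    # Guard with is_cuda() first, as the updated call sites do, because the
    # helpers query torch.cuda device capability and the CUDA runtime version.
    _is_sm90_supported = is_cuda() and is_sm90_supported()
    _is_sm100_supported = is_cuda() and is_sm100_supported()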
@@ -26,11 +26,14 @@ from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.radix_attention import AttentionType
-from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
-from sglang.srt.utils import is_flashinfer_available, next_power_of_2
+from sglang.srt.utils import (
+    is_flashinfer_available,
+    is_sm100_supported,
+    next_power_of_2,
+)

 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
@@ -28,11 +28,14 @@ from sglang.srt.layers.attention.flashinfer_backend import (
     create_flashinfer_kv_indices_triton,
 )
 from sglang.srt.layers.dp_attention import get_attention_tp_size
-from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
-from sglang.srt.utils import is_flashinfer_available, next_power_of_2
+from sglang.srt.utils import (
+    is_flashinfer_available,
+    is_sm100_supported,
+    next_power_of_2,
+)

 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
@@ -40,10 +40,9 @@ from sglang.srt.layers.moe import (
     get_moe_a2a_backend,
     should_use_flashinfer_cutlass_moe_fp4_allgather,
 )
-from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import is_cuda, is_flashinfer_available
+from sglang.srt.utils import is_cuda, is_flashinfer_available, is_sm100_supported

 _is_flashinfer_available = is_flashinfer_available()
 _is_sm100_supported = is_cuda() and is_sm100_supported()
"""CUTLASS based Fused MoE kernels.""" """CUTLASS based Fused MoE kernels."""
import functools
import json
import logging
import os
from typing import Any, Callable, Dict, List, Optional, Tuple
import torch import torch
from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams
from sglang.srt.layers.utils import is_sm90_supported, is_sm100_supported
from sglang.srt.utils import is_cuda from sglang.srt.utils import is_cuda
_is_cuda = is_cuda() _is_cuda = is_cuda()
if _is_cuda: if _is_cuda:
import sgl_kernel
from sgl_kernel import ( from sgl_kernel import (
apply_shuffle_mul_sum, apply_shuffle_mul_sum,
cutlass_fp4_group_mm, cutlass_fp4_group_mm,
......
@@ -64,7 +64,6 @@ from sglang.srt.layers.quantization.utils import (
     per_tensor_dequantize,
     requantize_with_max_scale,
 )
-from sglang.srt.layers.utils import is_sm90_supported, is_sm100_supported
 from sglang.srt.utils import (
     cpu_has_amx_support,
     get_bool_env_var,
@@ -72,6 +71,8 @@ from sglang.srt.utils import (
     is_cuda,
     is_hip,
     is_npu,
+    is_sm90_supported,
+    is_sm100_supported,
     log_info_on_rank0,
     next_power_of_2,
     print_warning_once,
@@ -5,7 +5,7 @@ import torch
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
 from sglang.srt.layers.quantization.mxfp4_tensor import MXFP4QuantizeUtil
-from sglang.srt.layers.utils import is_sm100_supported
+from sglang.srt.utils import is_sm100_supported

 try:
     from vllm import _custom_ops as ops
@@ -29,14 +29,13 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.utils import is_layer_skipped
-from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import (
     direct_register_custom_op,
-    get_bool_env_var,
     is_cuda,
     is_flashinfer_available,
     is_hip,
+    is_sm100_supported,
     is_triton_kernels_available,
     log_info_on_rank0,
     mxfp_supported,
@@ -34,17 +34,3 @@ class PPMissingLayer(torch.nn.Identity):
         """
         input = args[0] if args else next(iter(kwargs.values()))
         return (input,) if self.return_tuple else input
-
-
-@lru_cache(maxsize=1)
-def is_sm100_supported(device=None) -> bool:
-    return (torch.cuda.get_device_capability(device)[0] == 10) and (
-        torch.version.cuda >= "12.8"
-    )
-
-
-@lru_cache(maxsize=1)
-def is_sm90_supported(device=None) -> bool:
-    return (torch.cuda.get_device_capability(device)[0] == 9) and (
-        torch.version.cuda >= "12.3"
-    )
@@ -66,7 +66,6 @@ from sglang.srt.layers.quantization import (
 )
 from sglang.srt.layers.sampler import Sampler
 from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
-from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.lora.lora_manager import LoRAManager
 from sglang.srt.lora.lora_registry import LoRARef
 from sglang.srt.managers.schedule_batch import (
@@ -121,6 +120,7 @@ from sglang.srt.utils import (
     is_hopper_with_cuda_12_3,
     is_no_spec_infer_or_topk_one,
     is_npu,
+    is_sm100_supported,
     monkey_patch_p2p_access_check,
     monkey_patch_vllm_gguf_config,
     set_cuda_arch,
@@ -87,8 +87,8 @@ from sglang.srt.layers.quantization.int8_utils import (
     block_dequant as int8_block_dequant,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.rotary_embedding import get_rope, get_rope_wrapper
-from sglang.srt.layers.utils import PPMissingLayer, get_layer_id, is_sm100_supported
+from sglang.srt.layers.rotary_embedding import get_rope_wrapper
+from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
@@ -114,6 +114,7 @@ from sglang.srt.utils import (
     is_flashinfer_available,
     is_hip,
     is_non_idle_and_non_empty,
+    is_sm100_supported,
     log_info_on_rank0,
     make_layers,
     use_intel_amx_backend,
@@ -58,7 +58,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8_utils import dequant_mxfp4
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
-from sglang.srt.layers.utils import PPMissingLayer, get_layer_id, is_sm100_supported
+from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
@@ -71,6 +71,7 @@ from sglang.srt.utils import (
     add_prefix,
     is_cuda,
     is_flashinfer_available,
+    is_sm100_supported,
     make_layers,
 )
@@ -25,7 +25,6 @@ from typing import List, Literal, Optional, Union

 from sglang.srt.function_call.function_call_parser import FunctionCallParser
 from sglang.srt.hf_transformers_utils import check_gguf_file, get_config
-from sglang.srt.layers.utils import is_sm90_supported, is_sm100_supported
 from sglang.srt.lora.lora_registry import LoRARef
 from sglang.srt.reasoning_parser import ReasoningParser
 from sglang.srt.utils import (
@@ -39,6 +38,8 @@ from sglang.srt.utils import (
     is_hip,
     is_port_available,
     is_remote_url,
+    is_sm90_supported,
+    is_sm100_supported,
     is_triton_kernels_available,
     is_valid_ipv6_address,
     nullable_str,
@@ -172,6 +172,20 @@ def is_blackwell():
     return torch.cuda.get_device_capability()[0] == 10


+@lru_cache(maxsize=1)
+def is_sm100_supported(device=None) -> bool:
+    return (torch.cuda.get_device_capability(device)[0] == 10) and (
+        torch.version.cuda >= "12.8"
+    )
+
+
+@lru_cache(maxsize=1)
+def is_sm90_supported(device=None) -> bool:
+    return (torch.cuda.get_device_capability(device)[0] == 9) and (
+        torch.version.cuda >= "12.3"
+    )
+
+
 _warned_bool_env_var_keys = set()