"git@developer.sourcefind.cn:change/sglang.git" did not exist on "8dc191f237550d81eba9137f1e04b499ed8a6742"
Unverified Commit 97cb762b, authored by JieXin Liang, committed by GitHub

[misc] remove is_cuda_available (#5319)

parent 11951820
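
The change is mechanical: the is_cuda_available() wrapper in sglang.srt.utils only forwarded to is_cuda() (see the final hunk below), so call sites now import is_cuda directly and cache the result in a module-level flag that gates the optional sgl_kernel imports. A minimal sketch of the call-site pattern the diff converges on; the else branch is illustrative and not part of the change:

    # Before this commit, call sites went through a thin wrapper:
    #     from sglang.srt.utils import is_cuda_available
    #     _is_cuda = is_cuda_available()
    # After, they import is_cuda directly and evaluate it once at import time.
    from sglang.srt.utils import is_cuda

    _is_cuda = is_cuda()

    if _is_cuda:
        # sgl_kernel extensions are only importable on CUDA builds.
        from sgl_kernel import silu_and_mul
    else:
        silu_and_mul = None  # illustrative fallback; real modules branch differently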
@@ -28,9 +28,9 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
 )
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.utils import is_cuda_available, set_weight_attrs
+from sglang.srt.utils import is_cuda, set_weight_attrs
 
-_is_cuda = is_cuda_available()
+_is_cuda = is_cuda()
 
 if _is_cuda:
     from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
...
@@ -3,10 +3,10 @@ import triton
 import triton.language as tl
 
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.utils import is_hip
+from sglang.srt.utils import is_cuda, is_hip
 
-is_cuda_available = torch.cuda.is_available()
-if is_cuda_available:
+_is_cuda = is_cuda()
+if _is_cuda:
     CUDA_CAPABILITY = torch.cuda.get_device_capability()
 
 _is_hip = is_hip()
...
@@ -1037,12 +1037,12 @@ def extend_attention_fwd(
         num_warps = 4
 
     else:
-        if is_cuda_available and CUDA_CAPABILITY[0] >= 9:
+        if _is_cuda and CUDA_CAPABILITY[0] >= 9:
             if Lq <= 256:
                 BLOCK_M, BLOCK_N = (128, 64)
             else:
                 BLOCK_M, BLOCK_N = (32, 64)
-        elif is_cuda_available and CUDA_CAPABILITY[0] >= 8:
+        elif _is_cuda and CUDA_CAPABILITY[0] >= 8:
             if Lq <= 128:
                 BLOCK_M, BLOCK_N = (128, 128)
             elif Lq <= 256:
...
@@ -23,10 +23,10 @@ import triton.language as tl
 from sglang.srt.layers.attention.triton_ops.prefill_attention import (
     context_attention_fwd,
 )
-from sglang.srt.utils import is_hip
+from sglang.srt.utils import is_cuda, is_hip
 
-is_cuda_available = torch.cuda.is_available()
-if is_cuda_available:
+_is_cuda = is_cuda()
+if _is_cuda:
     CUDA_CAPABILITY = torch.cuda.get_device_capability()
 
 _is_hip = is_hip()
...
@@ -345,12 +345,12 @@ def extend_attention_fwd(
         num_warps = 4
 
     else:
-        if is_cuda_available and CUDA_CAPABILITY[0] >= 9:
+        if _is_cuda and CUDA_CAPABILITY[0] >= 9:
             if Lq <= 256:
                 BLOCK_M, BLOCK_N = (128, 64)
             else:
                 BLOCK_M, BLOCK_N = (32, 64)
-        elif is_cuda_available and CUDA_CAPABILITY[0] >= 8:
+        elif _is_cuda and CUDA_CAPABILITY[0] >= 8:
             # sm86/sm89 has a much smaller shared memory size (100K) than sm80 (160K)
             if CUDA_CAPABILITY[1] == 9 or CUDA_CAPABILITY[1] == 6:
                 if Lq <= 128:
...
@@ -22,8 +22,12 @@ import torch
 import triton
 import triton.language as tl
 
-is_cuda_available = torch.cuda.is_available()
-if is_cuda_available:
+from sglang.srt.utils import is_cuda, is_hip
+
+_is_cuda = is_cuda()
+_is_hip = is_hip()
+
+if _is_cuda or _is_hip:
     CUDA_CAPABILITY = torch.cuda.get_device_capability()
 
...
@@ -172,7 +176,7 @@ def context_attention_fwd(
     b_seq_len: [b]
     out: [b * s, head, head_dim]
     """
-    if is_cuda_available and CUDA_CAPABILITY[0] > 8:
+    if (_is_cuda or _is_hip) and CUDA_CAPABILITY[0] > 8:
         BLOCK = 128
     else:
         BLOCK = 64
...
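
The two hunks above go beyond the rename: this module previously keyed its compute-capability probe on torch.cuda.is_available() and now uses is_cuda() or is_hip(), so the capability lookup and the block-size choice in context_attention_fwd are exercised on ROCm builds as well. A hedged sketch of the resulting guard; the pick_block helper is illustrative, only the guard itself comes from the diff:

    import torch

    from sglang.srt.utils import is_cuda, is_hip

    _is_cuda = is_cuda()
    _is_hip = is_hip()

    if _is_cuda or _is_hip:
        # torch.cuda also fronts ROCm devices in HIP builds of PyTorch,
        # so get_device_capability() is valid under either flag.
        CUDA_CAPABILITY = torch.cuda.get_device_capability()

    def pick_block() -> int:
        # Short-circuits before touching CUDA_CAPABILITY on builds with neither backend.
        if (_is_cuda or _is_hip) and CUDA_CAPABILITY[0] > 8:
            return 128
        return 64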
@@ -20,9 +20,9 @@ import torch
 import torch.nn as nn
 
 from sglang.srt.custom_op import CustomOp
-from sglang.srt.utils import is_cuda_available
+from sglang.srt.utils import is_cuda
 
-_is_cuda = is_cuda_available()
+_is_cuda = is_cuda()
 
 if _is_cuda:
     from sgl_kernel import (
...
@@ -22,9 +22,9 @@ from sglang.srt.layers.quantization.utils import (
     requantize_with_max_scale,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import is_cuda_available
+from sglang.srt.utils import is_cuda
 
-if is_cuda_available():
+if is_cuda():
     from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant
 
 # Initialize logger for the module
...
@@ -11,10 +11,10 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
-from sglang.srt.utils import is_cuda_available, set_weight_attrs
+from sglang.srt.utils import is_cuda, set_weight_attrs
 
-is_cuda = is_cuda_available()
+_is_cuda = is_cuda()
 
-if is_cuda:
+if _is_cuda:
     from sgl_kernel import int8_scaled_mm
...
@@ -8,11 +8,11 @@ import torch
 import torch.nn as nn
 
 from sglang.srt.custom_op import CustomOp
-from sglang.srt.utils import is_cuda_available
+from sglang.srt.utils import is_cuda
 
-_is_cuda_available = is_cuda_available()
+_is_cuda = is_cuda()
 
-if _is_cuda_available:
+if _is_cuda:
     from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
 else:
     from vllm._custom_ops import rotary_embedding as vllm_rotary_embedding
...
@@ -82,7 +82,7 @@ class RotaryEmbedding(CustomOp):
         cache = self._compute_cos_sin_cache()
         # NOTE(ByronHsu): cache needs to be in FP32 for numerical stability
-        if not _is_cuda_available:
+        if not _is_cuda:
             cache = cache.to(dtype)
         self.cos_sin_cache: torch.Tensor
         self.register_buffer("cos_sin_cache", cache, persistent=False)
...
@@ -149,7 +149,7 @@ class RotaryEmbedding(CustomOp):
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        if _is_cuda_available and (self.head_size in [64, 128, 256, 512]):
+        if _is_cuda and (self.head_size in [64, 128, 256, 512]):
             apply_rope_with_cos_sin_cache_inplace(
                 positions=positions,
                 query=query,
...
@@ -652,7 +652,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
     def forward(self, *args, **kwargs):
         if torch.compiler.is_compiling():
             return self.forward_native(*args, **kwargs)
-        if _is_cuda_available:
+        if _is_cuda:
             return self.forward_cuda(*args, **kwargs)
         else:
             return self.forward_native(*args, **kwargs)
...
@@ -10,9 +10,9 @@ from sglang.srt.layers.dp_attention import get_attention_tp_group
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
-from sglang.srt.utils import crash_on_warnings, get_bool_env_var, is_cuda_available
+from sglang.srt.utils import crash_on_warnings, get_bool_env_var, is_cuda
 
-if is_cuda_available():
+if is_cuda():
     from sgl_kernel import (
         min_p_sampling_from_probs,
         top_k_renorm_prob,
...
@@ -40,9 +40,9 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import add_prefix, is_cuda_available
+from sglang.srt.utils import add_prefix, is_cuda
 
-if is_cuda_available():
+if is_cuda():
     from sgl_kernel import bmm_fp8
...
@@ -4,9 +4,9 @@ from typing import List
 
 import torch
 
-from sglang.srt.utils import is_cuda_available, is_hip
+from sglang.srt.utils import is_cuda, is_hip
 
-if is_cuda_available() or is_hip():
+if is_cuda() or is_hip():
     from sgl_kernel import (
         build_tree_kernel_efficient as sgl_build_tree_kernel_efficient,
     )
...
@@ -19,9 +19,9 @@ from sglang.srt.managers.schedule_batch import (
 from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
 from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient
-from sglang.srt.utils import fast_topk, is_cuda_available, is_hip, next_power_of_2
+from sglang.srt.utils import fast_topk, is_cuda, is_hip, next_power_of_2
 
-if is_cuda_available():
+if is_cuda():
     from sgl_kernel import (
         top_k_renorm_prob,
         top_p_renorm_prob,
...
@@ -34,14 +34,9 @@ from sglang.srt.speculative.eagle_utils import (
     select_top_k_tokens,
 )
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
-from sglang.srt.utils import (
-    empty_context,
-    fast_topk,
-    get_available_gpu_memory,
-    is_cuda_available,
-)
+from sglang.srt.utils import empty_context, fast_topk, get_available_gpu_memory, is_cuda
 
-if is_cuda_available():
+if is_cuda():
     from sgl_kernel import segment_packbits
 
 logger = logging.getLogger(__name__)
...
@@ -130,10 +130,6 @@ def is_flashinfer_available():
     return importlib.util.find_spec("flashinfer") is not None and is_cuda()
 
 
-def is_cuda_available():
-    return is_cuda()
-
-
 _ENABLE_TORCH_INFERENCE_MODE = get_bool_env_var(
     "SGLANG_ENABLE_TORCH_INFERENCE_MODE", "false"
 )
...
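
The last hunk deletes the wrapper itself from sglang.srt.utils. Because it only returned is_cuda(), the renames above are behavior-preserving; the only call sites whose behavior could shift are the Triton attention modules that previously checked torch.cuda.is_available() directly. A small illustrative check of the equivalence, with the wrapper re-created verbatim from the removed lines:

    from sglang.srt.utils import is_cuda

    def is_cuda_available():
        # The removed wrapper: a pure alias for is_cuda().
        return is_cuda()

    # Renaming call sites from is_cuda_available() to is_cuda() is therefore a no-op.
    assert is_cuda_available() == is_cuda()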