Unverified Commit 97cb762b authored by JieXin Liang, committed by GitHub

[misc] remove is_cuda_available (#5319)

parent 11951820
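
For context, the helpers that call sites switch to in this commit are the existing is_cuda() and is_hip() functions in sglang.srt.utils. Below is a minimal sketch of what such device checks typically look like; only the function names come from the diff, the bodies here are an assumption and the real implementations in the repository may differ.

# Hypothetical sketch of the device-check helpers in sglang.srt.utils that the
# call sites below standardize on (not the repository's actual code).
import torch


def is_cuda() -> bool:
    # True only for NVIDIA CUDA builds of PyTorch (torch.version.cuda is set).
    return torch.cuda.is_available() and torch.version.cuda is not None


def is_hip() -> bool:
    # True only for AMD ROCm builds of PyTorch (torch.version.hip is set).
    return torch.version.hip is not None

Note that a raw torch.cuda.is_available() check returns True on both CUDA and ROCm builds, whereas is_cuda() is CUDA-only; this is why one call site below now tests is_cuda() or is_hip() instead.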
@@ -28,9 +28,9 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
 )
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.utils import is_cuda_available, set_weight_attrs
+from sglang.srt.utils import is_cuda, set_weight_attrs

-_is_cuda = is_cuda_available()
+_is_cuda = is_cuda()

 if _is_cuda:
     from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
......
@@ -3,10 +3,10 @@ import triton
 import triton.language as tl

 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.utils import is_hip
+from sglang.srt.utils import is_cuda, is_hip

-is_cuda_available = torch.cuda.is_available()
-if is_cuda_available:
+_is_cuda = is_cuda()
+if _is_cuda:
     CUDA_CAPABILITY = torch.cuda.get_device_capability()

 _is_hip = is_hip()
@@ -1037,12 +1037,12 @@ def extend_attention_fwd(
         num_warps = 4
     else:
-        if is_cuda_available and CUDA_CAPABILITY[0] >= 9:
+        if _is_cuda and CUDA_CAPABILITY[0] >= 9:
             if Lq <= 256:
                 BLOCK_M, BLOCK_N = (128, 64)
             else:
                 BLOCK_M, BLOCK_N = (32, 64)
-        elif is_cuda_available and CUDA_CAPABILITY[0] >= 8:
+        elif _is_cuda and CUDA_CAPABILITY[0] >= 8:
             if Lq <= 128:
                 BLOCK_M, BLOCK_N = (128, 128)
             elif Lq <= 256:
......
@@ -23,10 +23,10 @@ import triton.language as tl
 from sglang.srt.layers.attention.triton_ops.prefill_attention import (
     context_attention_fwd,
 )
-from sglang.srt.utils import is_hip
+from sglang.srt.utils import is_cuda, is_hip

-is_cuda_available = torch.cuda.is_available()
-if is_cuda_available:
+_is_cuda = is_cuda()
+if _is_cuda:
     CUDA_CAPABILITY = torch.cuda.get_device_capability()

 _is_hip = is_hip()
@@ -345,12 +345,12 @@ def extend_attention_fwd(
         num_warps = 4
     else:
-        if is_cuda_available and CUDA_CAPABILITY[0] >= 9:
+        if _is_cuda and CUDA_CAPABILITY[0] >= 9:
             if Lq <= 256:
                 BLOCK_M, BLOCK_N = (128, 64)
             else:
                 BLOCK_M, BLOCK_N = (32, 64)
-        elif is_cuda_available and CUDA_CAPABILITY[0] >= 8:
+        elif _is_cuda and CUDA_CAPABILITY[0] >= 8:
             # sm86/sm89 has a much smaller shared memory size (100K) than sm80 (160K)
             if CUDA_CAPABILITY[1] == 9 or CUDA_CAPABILITY[1] == 6:
                 if Lq <= 128:
......
@@ -22,8 +22,12 @@ import torch
 import triton
 import triton.language as tl

-is_cuda_available = torch.cuda.is_available()
-if is_cuda_available:
+from sglang.srt.utils import is_cuda, is_hip
+
+_is_cuda = is_cuda()
+_is_hip = is_hip()
+
+if _is_cuda or _is_hip:
     CUDA_CAPABILITY = torch.cuda.get_device_capability()
@@ -172,7 +176,7 @@ def context_attention_fwd(
     b_seq_len: [b]
     out: [b * s, head, head_dim]
     """
-    if is_cuda_available and CUDA_CAPABILITY[0] > 8:
+    if (_is_cuda or _is_hip) and CUDA_CAPABILITY[0] > 8:
         BLOCK = 128
     else:
         BLOCK = 64
......
@@ -20,9 +20,9 @@ import torch
 import torch.nn as nn

 from sglang.srt.custom_op import CustomOp
-from sglang.srt.utils import is_cuda_available
+from sglang.srt.utils import is_cuda

-_is_cuda = is_cuda_available()
+_is_cuda = is_cuda()

 if _is_cuda:
     from sgl_kernel import (
......
@@ -22,9 +22,9 @@ from sglang.srt.layers.quantization.utils import (
     requantize_with_max_scale,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import is_cuda_available
+from sglang.srt.utils import is_cuda

-if is_cuda_available():
+if is_cuda():
     from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant

 # Initialize logger for the module
......
@@ -11,10 +11,10 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
-from sglang.srt.utils import is_cuda_available, set_weight_attrs
+from sglang.srt.utils import is_cuda, set_weight_attrs

-is_cuda = is_cuda_available()
-if is_cuda:
+_is_cuda = is_cuda()
+if _is_cuda:
     from sgl_kernel import int8_scaled_mm
......
@@ -8,11 +8,11 @@ import torch
 import torch.nn as nn

 from sglang.srt.custom_op import CustomOp
-from sglang.srt.utils import is_cuda_available
+from sglang.srt.utils import is_cuda

-_is_cuda_available = is_cuda_available()
+_is_cuda = is_cuda()

-if _is_cuda_available:
+if _is_cuda:
     from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
 else:
     from vllm._custom_ops import rotary_embedding as vllm_rotary_embedding
@@ -82,7 +82,7 @@ class RotaryEmbedding(CustomOp):
         cache = self._compute_cos_sin_cache()
         # NOTE(ByronHsu): cache needs to be in FP32 for numerical stability
-        if not _is_cuda_available:
+        if not _is_cuda:
             cache = cache.to(dtype)
         self.cos_sin_cache: torch.Tensor
         self.register_buffer("cos_sin_cache", cache, persistent=False)
@@ -149,7 +149,7 @@ class RotaryEmbedding(CustomOp):
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        if _is_cuda_available and (self.head_size in [64, 128, 256, 512]):
+        if _is_cuda and (self.head_size in [64, 128, 256, 512]):
             apply_rope_with_cos_sin_cache_inplace(
                 positions=positions,
                 query=query,
@@ -652,7 +652,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
     def forward(self, *args, **kwargs):
         if torch.compiler.is_compiling():
             return self.forward_native(*args, **kwargs)
-        if _is_cuda_available:
+        if _is_cuda:
             return self.forward_cuda(*args, **kwargs)
         else:
             return self.forward_native(*args, **kwargs)
......
@@ -10,9 +10,9 @@ from sglang.srt.layers.dp_attention import get_attention_tp_group
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
-from sglang.srt.utils import crash_on_warnings, get_bool_env_var, is_cuda_available
+from sglang.srt.utils import crash_on_warnings, get_bool_env_var, is_cuda

-if is_cuda_available():
+if is_cuda():
     from sgl_kernel import (
         min_p_sampling_from_probs,
         top_k_renorm_prob,
......
@@ -40,9 +40,9 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import add_prefix, is_cuda_available
+from sglang.srt.utils import add_prefix, is_cuda

-if is_cuda_available():
+if is_cuda():
     from sgl_kernel import bmm_fp8
......
@@ -4,9 +4,9 @@ from typing import List

 import torch

-from sglang.srt.utils import is_cuda_available, is_hip
+from sglang.srt.utils import is_cuda, is_hip

-if is_cuda_available() or is_hip():
+if is_cuda() or is_hip():
     from sgl_kernel import (
         build_tree_kernel_efficient as sgl_build_tree_kernel_efficient,
     )
......
@@ -19,9 +19,9 @@ from sglang.srt.managers.schedule_batch import (
 from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
 from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient
-from sglang.srt.utils import fast_topk, is_cuda_available, is_hip, next_power_of_2
+from sglang.srt.utils import fast_topk, is_cuda, is_hip, next_power_of_2

-if is_cuda_available():
+if is_cuda():
     from sgl_kernel import (
         top_k_renorm_prob,
         top_p_renorm_prob,
......
@@ -34,14 +34,9 @@ from sglang.srt.speculative.eagle_utils import (
     select_top_k_tokens,
 )
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
-from sglang.srt.utils import (
-    empty_context,
-    fast_topk,
-    get_available_gpu_memory,
-    is_cuda_available,
-)
+from sglang.srt.utils import empty_context, fast_topk, get_available_gpu_memory, is_cuda

-if is_cuda_available():
+if is_cuda():
     from sgl_kernel import segment_packbits

 logger = logging.getLogger(__name__)
......
@@ -130,10 +130,6 @@ def is_flashinfer_available():
     return importlib.util.find_spec("flashinfer") is not None and is_cuda()

-def is_cuda_available():
-    return is_cuda()

 _ENABLE_TORCH_INFERENCE_MODE = get_bool_env_var(
     "SGLANG_ENABLE_TORCH_INFERENCE_MODE", "false"
 )
......
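
Taken together, the migration pattern across these files is uniform. A condensed before/after sketch follows; the module paths are taken from the hunks above, and silu_and_mul stands in for whichever guarded sgl_kernel import a given file uses.

# Before this commit (removed):
#   from sglang.srt.utils import is_cuda_available
#   _is_cuda = is_cuda_available()   # or a raw torch.cuda.is_available() check
#
# After this commit:
from sglang.srt.utils import is_cuda

_is_cuda = is_cuda()
if _is_cuda:
    # CUDA-only kernels stay guarded behind the platform flag.
    from sgl_kernel import silu_and_mul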