Unverified Commit 6a5b352a authored by Lianmin Zheng, committed by GitHub

Use is_flashinfer_available to replace is_hip for flashinfer check (#1596)


Co-authored-by: Zhang Liangang <liangang.zhang@intel.com>
parent 565b05f0
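
The commit swaps the platform test `is_hip()` for a capability test `is_flashinfer_available()` at every site that conditionally imports FlashInfer kernels, and adds the helper itself to `sglang.srt.utils`. The sketch below is illustrative only: it condenses one of the touched modules into a single snippet, but every name in it appears in the diff that follows.

```python
# Illustrative sketch of the gating pattern this commit applies; not verbatim
# from any single file, though all names below appear in the diff.
import logging

from sglang.srt.utils import is_flashinfer_available

logger = logging.getLogger(__name__)

if is_flashinfer_available():
    # NVIDIA GPU with FlashInfer installed: use the fused kernels.
    from flashinfer.activation import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul

# ... rest of the module body ...

if not is_flashinfer_available():
    # Any other platform (e.g. AMD ROCm): fall back to vLLM's implementations.
    logger.info(
        "FlashInfer is not available on Non-NV GPUs. Fallback to other kernel libraries."
    )
    from vllm.model_executor.layers.activation import GeluAndMul, SiluAndMul
```

Checking the capability rather than the GPU vendor keeps the fallback path correct on any platform where FlashInfer is missing, not only ROCm.
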
@@ -20,9 +20,9 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F

-from sglang.srt.utils import is_hip
+from sglang.srt.utils import is_flashinfer_available

-if not is_hip():
+if is_flashinfer_available():
     from flashinfer.activation import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul

 from vllm.distributed import (
@@ -146,8 +146,8 @@ def get_act_fn(
     return act_fn


-if is_hip():
+if not is_flashinfer_available():
     logger.info(
-        "FlashInfer is not available on AMD GPUs. Fallback to other kernel libraries."
+        "FlashInfer is not available on Non-NV GPUs. Fallback to other kernel libraries."
     )
     from vllm.model_executor.layers.activation import GeluAndMul, SiluAndMul
@@ -19,13 +19,12 @@ from sglang.srt.layers.attention.flashinfer_utils import (
     update_flashinfer_indices,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
-from sglang.srt.utils import is_hip
+from sglang.srt.utils import is_flashinfer_available

 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner

-# ROCm: flashinfer available later
-if not is_hip():
+if is_flashinfer_available():
     from flashinfer import (
         BatchDecodeWithPagedKVCacheWrapper,
         BatchPrefillWithPagedKVCacheWrapper,
...
@@ -21,9 +21,9 @@ from typing import Optional, Tuple, Union
 import torch
 import torch.nn as nn

-from sglang.srt.utils import is_hip
+from sglang.srt.utils import is_flashinfer_available

-if not is_hip():
+if is_flashinfer_available():
     from flashinfer.norm import (
         fused_add_rmsnorm,
         gemma_fused_add_rmsnorm,
@@ -119,8 +119,8 @@ class GemmaRMSNorm(CustomOp):
         return out


-if is_hip():
+if not is_flashinfer_available():
     logger.info(
-        "FlashInfer is not available on AMD GPUs. Fallback to other kernel libraries."
+        "FlashInfer is not available on Non-NV platforms. Fallback to other kernel libraries."
     )
     from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
@@ -7,10 +7,9 @@ from torch import nn
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
-from sglang.srt.utils import is_hip
+from sglang.srt.utils import is_flashinfer_available

-# ROCm: flashinfer available later
-if not is_hip():
+if is_flashinfer_available():
     from flashinfer.sampling import (
         min_p_sampling_from_probs,
         top_k_renorm_prob,
...
@@ -25,13 +25,11 @@ import torch
 from sglang.srt.lora.lora import LoRAAdapter, get_lora_layer
 from sglang.srt.lora.lora_config import LoRAConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import is_hip, replace_submodule
+from sglang.srt.utils import is_flashinfer_available, replace_submodule

 logger = logging.getLogger(__name__)

-# ROCm: flashinfer available later
-if not is_hip():
+if is_flashinfer_available():
     from flashinfer import SegmentGEMMWrapper
...
@@ -47,10 +47,9 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import is_hip
+from sglang.srt.utils import is_flashinfer_available

-# ROCm: flashinfer available later
-if not is_hip():
+if is_flashinfer_available():
     from flashinfer import bmm_fp8
...
@@ -43,10 +43,9 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import is_hip
+from sglang.srt.utils import is_flashinfer_available

-# ROCm: flashinfer available later
-if not is_hip():
+if is_flashinfer_available():
     from flashinfer import bmm_fp8
...
@@ -22,7 +22,7 @@ import random
 import tempfile
 from typing import List, Optional

-from sglang.srt.utils import is_hip, is_ipv6, is_port_available
+from sglang.srt.utils import is_flashinfer_available, is_ipv6, is_port_available

 logger = logging.getLogger(__name__)
@@ -151,8 +151,7 @@ class ServerArgs:
             )
             self.sampling_backend = "pytorch"

-        # ROCm: flashinfer available later
-        if is_hip():
+        if not is_flashinfer_available():
             self.attention_backend = "triton"
             self.sampling_backend = "pytorch"
...
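
The user-visible effect of the ServerArgs change above, as a hedged sketch. The model path is a placeholder, not a real checkpoint, and the assertions only apply on machines where FlashInfer is unavailable.

```python
# Hedged sketch: how the ServerArgs fallback above surfaces at runtime.
# "dummy-model" is a placeholder path for illustration only.
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import is_flashinfer_available

args = ServerArgs(model_path="dummy-model")

if not is_flashinfer_available():
    # On a platform without FlashInfer (AMD ROCm, or no CUDA device),
    # __post_init__ now selects the Triton attention backend and the
    # PyTorch sampling backend instead of the FlashInfer ones.
    assert args.attention_backend == "triton"
    assert args.sampling_backend == "pytorch"
```
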
@@ -50,11 +50,19 @@ show_time_cost = False
 time_infos = {}


-# torch flag AMD GPU
 def is_hip() -> bool:
+    """Return whether it is HIP on the AMD ROCm platform."""
     return torch.version.hip is not None


+def is_flashinfer_available():
+    """
+    Check whether flashinfer is available.
+    As of Oct. 6, 2024, it is only available on NVIDIA GPUs.
+    """
+    return torch.cuda.is_available() and not is_hip()
+
+
 def is_ipv6(address):
     try:
         ipaddress.IPv6Address(address)
...
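
A quick, illustrative look at what the new helper evaluates to; the comments note the two cases in which it returns False.

```python
# Illustrative only: the helper is a pure capability check, so it can be
# probed directly without starting a server.
import torch

from sglang.srt.utils import is_flashinfer_available, is_hip

# True  -> CUDA build of PyTorch with an NVIDIA GPU visible.
# False -> AMD ROCm build (torch.version.hip is set, so is_hip() is True),
#          or no GPU at all (torch.cuda.is_available() is False).
print(torch.cuda.is_available(), is_hip(), is_flashinfer_available())
```
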