"vscode:/vscode.git/clone" did not exist on "e3936d4fb37cc0cd3a7cd9ffb58f357c5f417fff"
Unverified commit 6a5b352a authored by Lianmin Zheng, committed by GitHub

Use is_flashinfer_available to replace is_hip for flashinfer check (#1596)


Co-authored-by: Zhang Liangang <liangang.zhang@intel.com>
parent 565b05f0
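
The change is mechanical across the touched files: every "if not is_hip():" guard around a flashinfer import becomes "if is_flashinfer_available():", and every "if is_hip():" fallback branch becomes "if not is_flashinfer_available():". A minimal sketch of the resulting pattern, condensed into a single if/else for illustration (symbol names come from the first hunk below; in the real modules the fallback import sits later in the file):

import logging

from sglang.srt.utils import is_flashinfer_available

logger = logging.getLogger(__name__)

if is_flashinfer_available():
    # Fused kernels are only importable where flashinfer can actually run.
    from flashinfer.activation import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
else:
    # Non-NVIDIA platforms fall back to the vLLM implementations.
    logger.info(
        "FlashInfer is not available on Non-NV GPUs. Fallback to other kernel libraries."
    )
    from vllm.model_executor.layers.activation import GeluAndMul, SiluAndMul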
@@ -20,9 +20,9 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from sglang.srt.utils import is_hip
+from sglang.srt.utils import is_flashinfer_available
 
-if not is_hip():
+if is_flashinfer_available():
     from flashinfer.activation import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
 
 from vllm.distributed import (
@@ -146,8 +146,8 @@ def get_act_fn(
     return act_fn
 
 
-if is_hip():
+if not is_flashinfer_available():
     logger.info(
-        "FlashInfer is not available on AMD GPUs. Fallback to other kernel libraries."
+        "FlashInfer is not available on Non-NV GPUs. Fallback to other kernel libraries."
     )
     from vllm.model_executor.layers.activation import GeluAndMul, SiluAndMul
@@ -19,13 +19,12 @@ from sglang.srt.layers.attention.flashinfer_utils import (
     update_flashinfer_indices,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
-from sglang.srt.utils import is_hip
+from sglang.srt.utils import is_flashinfer_available
 
 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
 
-# ROCm: flashinfer available later
-if not is_hip():
+if is_flashinfer_available():
     from flashinfer import (
         BatchDecodeWithPagedKVCacheWrapper,
         BatchPrefillWithPagedKVCacheWrapper,
@@ -21,9 +21,9 @@ from typing import Optional, Tuple, Union
 import torch
 import torch.nn as nn
 
-from sglang.srt.utils import is_hip
+from sglang.srt.utils import is_flashinfer_available
 
-if not is_hip():
+if is_flashinfer_available():
    from flashinfer.norm import (
        fused_add_rmsnorm,
        gemma_fused_add_rmsnorm,
@@ -119,8 +119,8 @@ class GemmaRMSNorm(CustomOp):
         return out
 
 
-if is_hip():
+if not is_flashinfer_available():
     logger.info(
-        "FlashInfer is not available on AMD GPUs. Fallback to other kernel libraries."
+        "FlashInfer is not available on Non-NV platforms. Fallback to other kernel libraries."
     )
     from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
@@ -7,10 +7,9 @@ from torch import nn
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
-from sglang.srt.utils import is_hip
+from sglang.srt.utils import is_flashinfer_available
 
-# ROCm: flashinfer available later
-if not is_hip():
+if is_flashinfer_available():
     from flashinfer.sampling import (
         min_p_sampling_from_probs,
         top_k_renorm_prob,
@@ -25,13 +25,11 @@ import torch
 from sglang.srt.lora.lora import LoRAAdapter, get_lora_layer
 from sglang.srt.lora.lora_config import LoRAConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import is_hip, replace_submodule
+from sglang.srt.utils import is_flashinfer_available, replace_submodule
 
 logger = logging.getLogger(__name__)
 
-# ROCm: flashinfer available later
-if not is_hip():
+if is_flashinfer_available():
     from flashinfer import SegmentGEMMWrapper
@@ -47,10 +47,9 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import is_hip
+from sglang.srt.utils import is_flashinfer_available
 
-# ROCm: flashinfer available later
-if not is_hip():
+if is_flashinfer_available():
     from flashinfer import bmm_fp8
@@ -43,10 +43,9 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import is_hip
+from sglang.srt.utils import is_flashinfer_available
 
-# ROCm: flashinfer available later
-if not is_hip():
+if is_flashinfer_available():
     from flashinfer import bmm_fp8
@@ -22,7 +22,7 @@ import random
 import tempfile
 from typing import List, Optional
 
-from sglang.srt.utils import is_hip, is_ipv6, is_port_available
+from sglang.srt.utils import is_flashinfer_available, is_ipv6, is_port_available
 
 logger = logging.getLogger(__name__)
@@ -151,8 +151,7 @@ class ServerArgs:
             )
             self.sampling_backend = "pytorch"
 
-        # ROCm: flashinfer available later
-        if is_hip():
+        if not is_flashinfer_available():
             self.attention_backend = "triton"
             self.sampling_backend = "pytorch"
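
Server defaults follow the same capability check: when flashinfer is unavailable, the attention backend falls back to "triton" and the sampling backend to "pytorch". A condensed sketch of that selection (the dataclass shape and the "flashinfer" defaults are assumptions for illustration; only the field names and fallback values come from the hunk above):

from dataclasses import dataclass

from sglang.srt.utils import is_flashinfer_available


@dataclass
class ServerArgsSketch:
    # Hypothetical defaults; the real ServerArgs carries many more fields.
    attention_backend: str = "flashinfer"
    sampling_backend: str = "flashinfer"

    def __post_init__(self):
        # Without flashinfer, choose backends that run on any platform.
        if not is_flashinfer_available():
            self.attention_backend = "triton"
            self.sampling_backend = "pytorch"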
@@ -50,11 +50,19 @@ show_time_cost = False
 time_infos = {}
 
 
 # torch flag AMD GPU
 def is_hip() -> bool:
     """Return whether it is HIP on the AMD ROCm platform."""
     return torch.version.hip is not None
 
 
+def is_flashinfer_available():
+    """
+    Check whether flashinfer is available.
+    As of Oct. 6, 2024, it is only available on NVIDIA GPUs.
+    """
+    return torch.cuda.is_available() and not is_hip()
+
+
 def is_ipv6(address):
     try:
         ipaddress.IPv6Address(address)
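
The new helper leans on two PyTorch facts: torch.cuda.is_available() reports True on both CUDA and ROCm builds, while torch.version.hip is set only under ROCm, so the conjunction holds only on NVIDIA GPUs. A self-contained sketch of the implied truth table (illustrative; actual results depend on the local PyTorch build):

import torch


def is_hip() -> bool:
    """Return whether it is HIP on the AMD ROCm platform."""
    return torch.version.hip is not None


def is_flashinfer_available():
    """True only on NVIDIA GPUs: CUDA is available and the build is not HIP."""
    return torch.cuda.is_available() and not is_hip()


# NVIDIA GPU, CUDA build: cuda.is_available()=True,  version.hip is None -> True
# AMD GPU, ROCm build:    cuda.is_available()=True,  version.hip is set  -> False
# CPU-only build:         cuda.is_available()=False                      -> False
print(is_flashinfer_available())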