"src/array/cuda/csr_mm.hip" did not exist on "619d735df5dc2a62eca5a00e11e4290407169cb1"
Unverified commit 752e6430, authored by Lianmin Zheng, committed by GitHub

Allow disabling flashinfer sampling kernel (#778)

parent 30db99b3
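
This change adds a separate `disable_flashinfer_sampling` option next to the existing `disable_flashinfer` flag, and moves the shared `global_server_args_dict` from `sglang.srt.server` into `sglang.srt.managers.controller.infer_batch`, where `ModelRunner` now fills it in at startup. A minimal usage sketch; the model path is a placeholder, everything else uses names from this diff:

from sglang.srt.server_args import ServerArgs

# Sketch: enable the new option programmatically. On the command line, the
# equivalent is the new --disable-flashinfer-sampling flag added below.
server_args = ServerArgs(
    model_path="<your-model>",          # placeholder, not part of this diff
    disable_flashinfer_sampling=True,   # keep flashinfer attention, disable only its sampling kernel
)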
@@ -7,8 +7,11 @@ from torch import nn
 from sglang.global_config import global_config
 from sglang.srt.layers.extend_attention import extend_attention_fwd
 from sglang.srt.layers.token_attention import token_attention_fwd
-from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata
-from sglang.srt.server import global_server_args_dict
+from sglang.srt.managers.controller.model_runner import (
+    ForwardMode,
+    InputMetadata,
+    global_server_args_dict,
+)


 class RadixAttention(nn.Module):
...
@@ -5,7 +5,7 @@ import torch
 import triton
 import triton.language as tl

-from sglang.srt.server import global_server_args_dict
+from sglang.srt.managers.controller.infer_batch import global_server_args_dict

 if global_server_args_dict.get("attention_reduce_in_fp32", False):
     REDUCE_TRITON_TYPE = tl.float32
...
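
For context, `attention_reduce_in_fp32` is read at import time to pick the reduction dtype for the Triton attention kernels, which is why the dict now has to live in a module the kernels can import without pulling in the server. A rough sketch of that pattern; only the fp32 branch appears in the hunk above, so the fp16 default here is an assumption:

import triton.language as tl

from sglang.srt.managers.controller.infer_batch import global_server_args_dict

# Choose the reduction dtype once, when the kernel module is imported.
if global_server_args_dict.get("attention_reduce_in_fp32", False):
    REDUCE_TRITON_TYPE = tl.float32
else:
    REDUCE_TRITON_TYPE = tl.float16  # assumed fallback; not shown in the diff

Since this runs at import time, the defaults defined in infer_batch apply unless the dict has already been updated before the kernel module is first imported.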
@@ -17,6 +17,13 @@ from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool

 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

+# Put some global args for easy access
+global_server_args_dict = {
+    "disable_flashinfer": False,
+    "disable_flashinfer_sampling": False,
+    "attention_reduce_in_fp32": False,
+}
+

 class ForwardMode(IntEnum):
     # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
...
@@ -687,7 +694,7 @@ class Batch:
         # TODO(lmzheng): apply penalty
         probs = torch.softmax(logits, dim=-1)

-        if True:
+        if not global_server_args_dict["disable_flashinfer_sampling"]:
             max_top_k_round, batch_size = 32, probs.shape[0]
             uniform_samples = torch.rand(
                 (max_top_k_round, batch_size), device=probs.device
...
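
The sampling code in `Batch` previously took the flashinfer path unconditionally (`if True:`); it is now gated on the new flag. A sketch of the resulting control flow, with the fallback written as plain `torch.multinomial`; the exact flashinfer call and the simplified fallback are assumptions, not taken from this hunk:

import torch

def sample_next_tokens(probs, top_ks, top_ps, disable_flashinfer_sampling):
    """Sketch of the two sampling paths guarded by the new flag.

    probs: (batch_size, vocab_size) softmax output; top_ks/top_ps hold the
    per-request sampling parameters.
    """
    if not disable_flashinfer_sampling:
        # Fused path: pre-draw uniform samples for up to 32 rejection-sampling
        # rounds (shapes taken from the diff above), then let flashinfer do
        # top-k/top-p sampling in one kernel. The call signature is my reading
        # of the flashinfer API at the time and should be treated as an assumption.
        from flashinfer.sampling import top_k_top_p_sampling_from_probs

        max_top_k_round, batch_size = 32, probs.shape[0]
        uniform_samples = torch.rand((max_top_k_round, batch_size), device=probs.device)
        next_token_ids, _ = top_k_top_p_sampling_from_probs(
            probs, uniform_samples, top_ks, top_ps
        )
        return next_token_ids
    # Fallback when the kernel is disabled: plain PyTorch sampling. The real
    # fallback applies top-k/top-p filtering first; that is omitted here.
    return torch.multinomial(probs, num_samples=1).squeeze(-1)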
@@ -25,7 +25,12 @@ from vllm.distributed import (
 from vllm.model_executor.models import ModelRegistry

 from sglang.global_config import global_config
-from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode, InputMetadata
+from sglang.srt.managers.controller.infer_batch import (
+    Batch,
+    ForwardMode,
+    InputMetadata,
+    global_server_args_dict,
+)
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
...
@@ -60,7 +65,13 @@ class ModelRunner:
         self.nccl_port = nccl_port
         self.server_args = server_args
         self.is_multimodal_model = is_multimodal_model(self.model_config)
-        monkey_patch_vllm_dummy_weight_loader()
+        global_server_args_dict.update(
+            {
+                "disable_flashinfer": server_args.disable_flashinfer,
+                "disable_flashinfer_sampling": server_args.disable_flashinfer_sampling,
+                "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
+            }
+        )

         # Init torch distributed
         torch.cuda.set_device(self.gpu_id)
...
@@ -108,6 +119,7 @@ class ModelRunner:
             f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )

+        monkey_patch_vllm_dummy_weight_loader()
         device_config = DeviceConfig()
         load_config = LoadConfig(load_format=self.server_args.load_format)
         vllm_model_config = VllmModelConfig(
...
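
With the hunks above, the flags are written exactly once, in `ModelRunner.__init__`, and read wherever `global_server_args_dict` is imported; the `_set_global_server_args` helper in the server entry point (removed in the next file) is no longer needed. One detail worth noting: `update()` mutates the existing dict instead of rebinding the name, so modules that already did `from ... import global_server_args_dict` keep seeing the new values. A self-contained sketch of the pattern, using a stand-in for `ServerArgs`:

# Module-level defaults, as defined in infer_batch in this diff.
global_server_args_dict = {
    "disable_flashinfer": False,
    "disable_flashinfer_sampling": False,
    "attention_reduce_in_fp32": False,
}


class FakeServerArgs:
    """Stand-in for sglang.srt.server_args.ServerArgs (illustration only)."""

    disable_flashinfer = False
    disable_flashinfer_sampling = True
    attention_reduce_in_fp32 = False


def init_global_args(server_args):
    # Mirrors the update() call added to ModelRunner.__init__ in this diff.
    global_server_args_dict.update(
        {
            "disable_flashinfer": server_args.disable_flashinfer,
            "disable_flashinfer_sampling": server_args.disable_flashinfer_sampling,
            "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
        }
    )


init_global_args(FakeServerArgs())
assert global_server_args_dict["disable_flashinfer_sampling"] is True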
@@ -65,9 +65,6 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 app = FastAPI()
 tokenizer_manager = None

-# Put some args for easily access
-global_server_args_dict = {}
-

 @app.get("/health")
 async def health() -> Response:
...
@@ -150,14 +147,6 @@ def available_models():
     return ModelList(data=model_cards)


-def _set_global_server_args(server_args: ServerArgs):
-    global global_server_args_dict
-    global_server_args_dict = {
-        "disable_flashinfer": server_args.disable_flashinfer,
-        "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
-    }
-
-
 def _set_torch_compile_config():
     # The following configurations are for torch compile optimizations
     import torch._dynamo.config
...
@@ -213,8 +202,6 @@ def launch_server(
     if server_args.enable_torch_compile:
         _set_torch_compile_config()

-    _set_global_server_args(server_args)
-
     # Allocate ports
     server_args.port, server_args.additional_ports = allocate_init_ports(
         server_args.port,
...
@@ -52,13 +52,14 @@ class ServerArgs:
     # Optimization/debug options
     disable_flashinfer: bool = False
+    disable_flashinfer_sampling: bool = False
     disable_radix_cache: bool = False
     disable_regex_jump_forward: bool = False
     disable_cuda_graph: bool = False
     disable_disk_cache: bool = False
     enable_torch_compile: bool = False
-    attention_reduce_in_fp32: bool = False
     enable_p2p_check: bool = False
+    attention_reduce_in_fp32: bool = False
     efficient_weight_load: bool = False

     # Distributed args
...
@@ -303,7 +304,12 @@
         parser.add_argument(
             "--disable-flashinfer",
             action="store_true",
-            help="Disable flashinfer inference kernels.",
+            help="Disable flashinfer attention kernels.",
+        )
+        parser.add_argument(
+            "--disable-flashinfer-sampling",
+            action="store_true",
+            help="Disable flashinfer sampling kernels.",
         )
         parser.add_argument(
             "--disable-radix-cache",
...
@@ -331,15 +337,15 @@
             help="Optimize the model with torch.compile, experimental feature.",
         )
         parser.add_argument(
-            "--attention-reduce-in-fp32",
+            "--enable-p2p-check",
             action="store_true",
-            help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16."
-            "This only affects Triton attention kernels",
+            help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.",
         )
         parser.add_argument(
-            "--enable-p2p-check",
+            "--attention-reduce-in-fp32",
             action="store_true",
-            help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.",
+            help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16."
+            "This only affects Triton attention kernels",
         )
         parser.add_argument(
             "--efficient-weight-load",
...
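
The new CLI flags map one-to-one onto the `ServerArgs` fields above (argparse turns the dashes into underscores). A stand-alone parsing sketch covering only the options touched by this diff; the real flags are registered on the server's own parser, and the help texts here are shortened:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--disable-flashinfer", action="store_true",
                    help="Disable flashinfer attention kernels.")
parser.add_argument("--disable-flashinfer-sampling", action="store_true",
                    help="Disable flashinfer sampling kernels.")
parser.add_argument("--attention-reduce-in-fp32", action="store_true",
                    help="Reduce attention in fp32 (Triton kernels only).")

args = parser.parse_args(["--disable-flashinfer-sampling"])
assert args.disable_flashinfer_sampling is True
assert args.disable_flashinfer is False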