sglang commit 752e6430 (Unverified)
Authored Jul 27, 2024 by Lianmin Zheng; committed by GitHub, Jul 27, 2024
Allow disabling flashinfer sampling kernel (#778)
Parent: 30db99b3

Showing 6 changed files with 41 additions and 26 deletions (+41, -26)
python/sglang/srt/layers/radix_attention.py            +5   -2
python/sglang/srt/layers/token_attention.py            +1   -1
python/sglang/srt/managers/controller/infer_batch.py   +8   -1
python/sglang/srt/managers/controller/model_runner.py  +14  -2
python/sglang/srt/server.py                            +0   -13
python/sglang/srt/server_args.py                       +13  -7
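With this commit, flashinfer attention and flashinfer sampling can be disabled independently: --disable-flashinfer now only turns off the attention kernels, while the new --disable-flashinfer-sampling turns off the sampling kernel on its own. For example, a launch that keeps flashinfer attention but falls back to PyTorch-native sampling would pass only the new flag, roughly "python -m sglang.launch_server --model-path <model> --disable-flashinfer-sampling" (the launch_server entry point and --model-path option are the usual sglang CLI of the time, not part of this diff).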
python/sglang/srt/layers/radix_attention.py  (+5, -2)

@@ -7,8 +7,11 @@ from torch import nn
 from sglang.global_config import global_config
 from sglang.srt.layers.extend_attention import extend_attention_fwd
 from sglang.srt.layers.token_attention import token_attention_fwd
-from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata
-from sglang.srt.server import global_server_args_dict
+from sglang.srt.managers.controller.model_runner import (
+    ForwardMode,
+    InputMetadata,
+    global_server_args_dict,
+)


 class RadixAttention(nn.Module):
python/sglang/srt/layers/token_attention.py  (+1, -1)

@@ -5,7 +5,7 @@ import torch
 import triton
 import triton.language as tl

-from sglang.srt.server import global_server_args_dict
+from sglang.srt.managers.controller.infer_batch import global_server_args_dict

 if global_server_args_dict.get("attention_reduce_in_fp32", False):
     REDUCE_TRITON_TYPE = tl.float32
python/sglang/srt/managers/controller/infer_batch.py  (+8, -1)

@@ -17,6 +17,13 @@ from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

+# Put some global args for easy access
+global_server_args_dict = {
+    "disable_flashinfer": False,
+    "disable_flashinfer_sampling": False,
+    "attention_reduce_in_fp32": False,
+}
+

 class ForwardMode(IntEnum):
     # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.

@@ -687,7 +694,7 @@ class Batch:
         # TODO(lmzheng): apply penalty
         probs = torch.softmax(logits, dim=-1)

-        if True:
+        if not global_server_args_dict["disable_flashinfer_sampling"]:
             max_top_k_round, batch_size = 32, probs.shape[0]
             uniform_samples = torch.rand(
                 (max_top_k_round, batch_size), device=probs.device
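The second hunk gates only the flashinfer sampling branch behind the new flag; the PyTorch path that runs when the flag is set lies outside this hunk. For orientation, a minimal top-k/top-p sampler of the kind such a fallback relies on could look like the sketch below. This is illustrative only, not code from the repository: it uses scalar top_k/top_p, whereas the batched code works with per-request tensors.

import torch

def top_k_top_p_sample(probs: torch.Tensor, top_k: int, top_p: float) -> torch.Tensor:
    # probs: (batch, vocab), already softmax-normalized as in the hunk above.
    sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # Drop tokens outside the nucleus; the top-1 token is always kept.
    sorted_probs = sorted_probs.masked_fill((cumulative - sorted_probs) > top_p, 0.0)
    # Drop tokens beyond the top-k cutoff.
    sorted_probs[:, top_k:] = 0.0
    # Renormalize and sample one token index per row, then map back to vocab ids.
    sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
    choice = torch.multinomial(sorted_probs, num_samples=1)
    return torch.gather(sorted_idx, dim=-1, index=choice).squeeze(-1)

# Example: sample one token per request from a batch of 4 distributions.
probs = torch.softmax(torch.randn(4, 32000), dim=-1)
token_ids = top_k_top_p_sample(probs, top_k=40, top_p=0.9)  # shape: (4,)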
python/sglang/srt/managers/controller/model_runner.py  (+14, -2)

@@ -25,7 +25,12 @@ from vllm.distributed import (
 from vllm.model_executor.models import ModelRegistry

 from sglang.global_config import global_config
-from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode, InputMetadata
+from sglang.srt.managers.controller.infer_batch import (
+    Batch,
+    ForwardMode,
+    InputMetadata,
+    global_server_args_dict,
+)
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (

@@ -60,7 +65,13 @@ class ModelRunner:
         self.nccl_port = nccl_port
         self.server_args = server_args
         self.is_multimodal_model = is_multimodal_model(self.model_config)
-        monkey_patch_vllm_dummy_weight_loader()
+        global_server_args_dict.update(
+            {
+                "disable_flashinfer": server_args.disable_flashinfer,
+                "disable_flashinfer_sampling": server_args.disable_flashinfer_sampling,
+                "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
+            }
+        )

         # Init torch distributed
         torch.cuda.set_device(self.gpu_id)

@@ -108,6 +119,7 @@ class ModelRunner:
             f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )

+        monkey_patch_vllm_dummy_weight_loader()
         device_config = DeviceConfig()
         load_config = LoadConfig(load_format=self.server_args.load_format)
         vllm_model_config = VllmModelConfig(
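Taken together with the infer_batch.py hunk, this is the structural core of the commit: the dict of server flags now lives in infer_batch.py with conservative defaults and is overwritten exactly once by ModelRunner at startup, so layer and batch code can import it without depending on server.py. A single-file sketch of that pattern, with illustrative names rather than the repository's module layout:

# Module-level defaults (plays the role of infer_batch.py): importable anywhere
# without touching the HTTP server.
global_server_args_dict = {
    "disable_flashinfer": False,
    "disable_flashinfer_sampling": False,
    "attention_reduce_in_fp32": False,
}

class DummyServerArgs:
    # Stand-in for sglang.srt.server_args.ServerArgs (illustrative only).
    disable_flashinfer = False
    disable_flashinfer_sampling = True
    attention_reduce_in_fp32 = False

def init_model_runner(server_args):
    # Mirrors ModelRunner.__init__: copy the parsed CLI flags into the shared dict once.
    global_server_args_dict.update(
        {
            "disable_flashinfer": server_args.disable_flashinfer,
            "disable_flashinfer_sampling": server_args.disable_flashinfer_sampling,
            "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
        }
    )

init_model_runner(DummyServerArgs())
# Consumers read the dict at call time, so they observe the updated values.
assert global_server_args_dict["disable_flashinfer_sampling"] is True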
python/sglang/srt/server.py  (+0, -13)

@@ -65,9 +65,6 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 app = FastAPI()
 tokenizer_manager = None

-# Put some args for easily access
-global_server_args_dict = {}
-

 @app.get("/health")
 async def health() -> Response:

@@ -150,14 +147,6 @@ def available_models():
     return ModelList(data=model_cards)


-def _set_global_server_args(server_args: ServerArgs):
-    global global_server_args_dict
-    global_server_args_dict = {
-        "disable_flashinfer": server_args.disable_flashinfer,
-        "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
-    }
-
-
 def _set_torch_compile_config():
     # The following configurations are for torch compile optimizations
     import torch._dynamo.config

@@ -213,8 +202,6 @@ def launch_server(
     if server_args.enable_torch_compile:
         _set_torch_compile_config()

-    _set_global_server_args(server_args)
-
     # Allocate ports
     server_args.port, server_args.additional_ports = allocate_init_ports(
         server_args.port,
python/sglang/srt/server_args.py  (+13, -7)

@@ -52,13 +52,14 @@ class ServerArgs:
     # Optimization/debug options
     disable_flashinfer: bool = False
+    disable_flashinfer_sampling: bool = False
     disable_radix_cache: bool = False
     disable_regex_jump_forward: bool = False
     disable_cuda_graph: bool = False
     disable_disk_cache: bool = False
     enable_torch_compile: bool = False
-    attention_reduce_in_fp32: bool = False
     enable_p2p_check: bool = False
+    attention_reduce_in_fp32: bool = False
     efficient_weight_load: bool = False

     # Distributed args

@@ -303,7 +304,12 @@ class ServerArgs:
         parser.add_argument(
             "--disable-flashinfer",
             action="store_true",
-            help="Disable flashinfer inference kernels.",
+            help="Disable flashinfer attention kernels.",
+        )
+        parser.add_argument(
+            "--disable-flashinfer-sampling",
+            action="store_true",
+            help="Disable flashinfer sampling kernels.",
         )
         parser.add_argument(
             "--disable-radix-cache",

@@ -331,15 +337,15 @@ class ServerArgs:
             help="Optimize the model with torch.compile, experimental feature.",
         )
         parser.add_argument(
-            "--attention-reduce-in-fp32",
+            "--enable-p2p-check",
             action="store_true",
-            help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16."
-            "This only affects Triton attention kernels",
+            help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.",
         )
         parser.add_argument(
-            "--enable-p2p-check",
+            "--attention-reduce-in-fp32",
             action="store_true",
-            help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.",
+            help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16."
+            "This only affects Triton attention kernels",
         )
         parser.add_argument(
             "--efficient-weight-load",
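For reference, the new dataclass field and the new store_true flag connect in the standard argparse way: the flag's dest (disable_flashinfer_sampling) matches the field name, so the parsed namespace can populate ServerArgs directly. A self-contained sketch, trimmed to the two flashinfer flags (illustrative, not the repository's full ServerArgs):

import argparse
from dataclasses import dataclass

@dataclass
class ServerArgs:
    disable_flashinfer: bool = False
    disable_flashinfer_sampling: bool = False

parser = argparse.ArgumentParser()
parser.add_argument("--disable-flashinfer", action="store_true",
                    help="Disable flashinfer attention kernels.")
parser.add_argument("--disable-flashinfer-sampling", action="store_true",
                    help="Disable flashinfer sampling kernels.")

args = parser.parse_args(["--disable-flashinfer-sampling"])
server_args = ServerArgs(**vars(args))
assert server_args.disable_flashinfer_sampling is True
assert server_args.disable_flashinfer is False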