sglang commit 752e6430 (unverified)
Authored Jul 27, 2024 by Lianmin Zheng; committed by GitHub on Jul 27, 2024
Parent: 30db99b3

Allow disabling flashinfer sampling kernel (#778)
Showing 6 changed files with 41 additions and 26 deletions (+41, -26):

  python/sglang/srt/layers/radix_attention.py             +5   -2
  python/sglang/srt/layers/token_attention.py             +1   -1
  python/sglang/srt/managers/controller/infer_batch.py    +8   -1
  python/sglang/srt/managers/controller/model_runner.py  +14   -2
  python/sglang/srt/server.py                             +0  -13
  python/sglang/srt/server_args.py                       +13   -7
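At a high level, the commit threads a new disable_flashinfer_sampling option from ServerArgs through ModelRunner into the shared global_server_args_dict, and gates the flashinfer sampling path in infer_batch.py on it. As orientation, here is a minimal, hypothetical usage sketch: the launch_server(server_args) call shape, the model_path value, and the CLI module path in the comment are assumptions, while the disable_flashinfer_sampling field and the --disable-flashinfer-sampling flag come from this diff.

# Hypothetical usage sketch (not part of this commit): turn off only the
# flashinfer sampling kernel while leaving flashinfer attention enabled.
from sglang.srt.server import launch_server        # entry point changed in this diff
from sglang.srt.server_args import ServerArgs

server_args = ServerArgs(
    model_path="meta-llama/Llama-2-7b-chat-hf",    # illustrative model path (assumption)
    disable_flashinfer_sampling=True,              # new field added by this commit
)

if __name__ == "__main__":
    # Roughly equivalent CLI (flag name taken from server_args.py below;
    # the launch_server module path is an assumption):
    #   python -m sglang.launch_server --model-path <model> --disable-flashinfer-sampling
    launch_server(server_args)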
python/sglang/srt/layers/radix_attention.py

@@ -7,8 +7,11 @@ from torch import nn
 from sglang.global_config import global_config
 from sglang.srt.layers.extend_attention import extend_attention_fwd
 from sglang.srt.layers.token_attention import token_attention_fwd
-from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata
-from sglang.srt.server import global_server_args_dict
+from sglang.srt.managers.controller.model_runner import (
+    ForwardMode,
+    InputMetadata,
+    global_server_args_dict,
+)


 class RadixAttention(nn.Module):
python/sglang/srt/layers/token_attention.py

@@ -5,7 +5,7 @@ import torch
 import triton
 import triton.language as tl

-from sglang.srt.server import global_server_args_dict
+from sglang.srt.managers.controller.infer_batch import global_server_args_dict

 if global_server_args_dict.get("attention_reduce_in_fp32", False):
     REDUCE_TRITON_TYPE = tl.float32
python/sglang/srt/managers/controller/infer_batch.py

@@ -17,6 +17,13 @@ from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

+# Put some global args for easy access
+global_server_args_dict = {
+    "disable_flashinfer": False,
+    "disable_flashinfer_sampling": False,
+    "attention_reduce_in_fp32": False,
+}
+

 class ForwardMode(IntEnum):
     # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.

@@ -687,7 +694,7 @@ class Batch:
         # TODO(lmzheng): apply penalty
         probs = torch.softmax(logits, dim=-1)

-        if True:
+        if not global_server_args_dict["disable_flashinfer_sampling"]:
             max_top_k_round, batch_size = 32, probs.shape[0]
             uniform_samples = torch.rand(
                 (max_top_k_round, batch_size), device=probs.device
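The second hunk replaces an unconditional "if True:" with a check on the new flag, so the flashinfer rejection-sampling path can be bypassed. The sketch below only illustrates that gating pattern: the top_k_top_p_sampling_from_probs import, its (probs, uniform_samples, top_ks, top_ps) call shape, and the torch.multinomial fallback are assumptions, not code from this commit; only the gate on the flag, max_top_k_round, and the pre-drawn uniform_samples mirror the hunk above.

# Illustrative sketch of the sampling gate; not the commit's exact code.
import torch

def sample_next_tokens(
    logits: torch.Tensor,
    top_ks: torch.Tensor,          # per-request top-k values (assumption)
    top_ps: torch.Tensor,          # per-request top-p values (assumption)
    disable_flashinfer_sampling: bool,
) -> torch.Tensor:
    probs = torch.softmax(logits, dim=-1)
    if not disable_flashinfer_sampling:
        # Assumed flashinfer API: batched top-k/top-p rejection sampling driven
        # by pre-drawn uniform noise, as suggested by the hunk above.
        from flashinfer.sampling import top_k_top_p_sampling_from_probs
        max_top_k_round, batch_size = 32, probs.shape[0]
        uniform_samples = torch.rand(
            (max_top_k_round, batch_size), device=probs.device
        )
        next_tokens, _ = top_k_top_p_sampling_from_probs(
            probs, uniform_samples, top_ks, top_ps
        )
        return next_tokens
    # Plain PyTorch fallback when the kernel is disabled (illustrative; it
    # samples from the full distribution without top-k/top-p filtering).
    return torch.multinomial(probs, num_samples=1).squeeze(-1)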
python/sglang/srt/managers/controller/model_runner.py

@@ -25,7 +25,12 @@ from vllm.distributed import (
 from vllm.model_executor.models import ModelRegistry

 from sglang.global_config import global_config
-from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode, InputMetadata
+from sglang.srt.managers.controller.infer_batch import (
+    Batch,
+    ForwardMode,
+    InputMetadata,
+    global_server_args_dict,
+)
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (

@@ -60,7 +65,13 @@ class ModelRunner:
         self.nccl_port = nccl_port
         self.server_args = server_args
         self.is_multimodal_model = is_multimodal_model(self.model_config)
-        monkey_patch_vllm_dummy_weight_loader()
+        global_server_args_dict.update(
+            {
+                "disable_flashinfer": server_args.disable_flashinfer,
+                "disable_flashinfer_sampling": server_args.disable_flashinfer_sampling,
+                "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
+            }
+        )

         # Init torch distributed
         torch.cuda.set_device(self.gpu_id)

@@ -108,6 +119,7 @@ class ModelRunner:
             f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )

+        monkey_patch_vllm_dummy_weight_loader()
         device_config = DeviceConfig()
         load_config = LoadConfig(load_format=self.server_args.load_format)
         vllm_model_config = VllmModelConfig(
python/sglang/srt/server.py

@@ -65,9 +65,6 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 app = FastAPI()
 tokenizer_manager = None

-# Put some args for easily access
-global_server_args_dict = {}
-

 @app.get("/health")
 async def health() -> Response:

@@ -150,14 +147,6 @@ def available_models():
     return ModelList(data=model_cards)


-def _set_global_server_args(server_args: ServerArgs):
-    global global_server_args_dict
-    global_server_args_dict = {
-        "disable_flashinfer": server_args.disable_flashinfer,
-        "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
-    }
-
-
 def _set_torch_compile_config():
     # The following configurations are for torch compile optimizations
     import torch._dynamo.config

@@ -213,8 +202,6 @@ def launch_server(
     if server_args.enable_torch_compile:
         _set_torch_compile_config()

-    _set_global_server_args(server_args)
-
     # Allocate ports
     server_args.port, server_args.additional_ports = allocate_init_ports(
         server_args.port,
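Taken together with the global_server_args_dict.update(...) call added in model_runner.py, this removal changes how the shared args reach consumer modules: instead of rebinding a module global inside server.py (which modules that did "from sglang.srt.server import global_server_args_dict" would never observe), the dict now lives in infer_batch.py and is mutated in place. The self-contained sketch below is not sglang code; it only demonstrates why the in-place .update() is visible to importers while rebinding is not.

# Minimal, self-contained illustration of the pattern (not sglang code).

shared_config = {"disable_flashinfer_sampling": False}  # stand-in for global_server_args_dict
consumer_view = shared_config                           # what `from module import name` captures

def update_in_place(flag: bool) -> None:
    # Mirrors the .update() call added in ModelRunner.__init__.
    shared_config.update({"disable_flashinfer_sampling": flag})

def rebind(flag: bool) -> None:
    # Mirrors the removed _set_global_server_args, which rebound the module
    # global to a brand-new dict.
    global shared_config
    shared_config = {"disable_flashinfer_sampling": flag}

update_in_place(True)
print(consumer_view["disable_flashinfer_sampling"])   # True: in-place mutation is visible

rebind(False)
print(consumer_view["disable_flashinfer_sampling"])   # still True: rebinding is not visible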
python/sglang/srt/server_args.py

@@ -52,13 +52,14 @@ class ServerArgs:
     # Optimization/debug options
     disable_flashinfer: bool = False
+    disable_flashinfer_sampling: bool = False
     disable_radix_cache: bool = False
     disable_regex_jump_forward: bool = False
     disable_cuda_graph: bool = False
     disable_disk_cache: bool = False
     enable_torch_compile: bool = False
-    attention_reduce_in_fp32: bool = False
     enable_p2p_check: bool = False
+    attention_reduce_in_fp32: bool = False
     efficient_weight_load: bool = False

     # Distributed args

@@ -303,7 +304,12 @@ class ServerArgs:
         parser.add_argument(
             "--disable-flashinfer",
             action="store_true",
-            help="Disable flashinfer inference kernels.",
+            help="Disable flashinfer attention kernels.",
+        )
+        parser.add_argument(
+            "--disable-flashinfer-sampling",
+            action="store_true",
+            help="Disable flashinfer sampling kernels.",
         )
         parser.add_argument(
             "--disable-radix-cache",

@@ -331,15 +337,15 @@ class ServerArgs:
             help="Optimize the model with torch.compile, experimental feature.",
         )
         parser.add_argument(
-            "--attention-reduce-in-fp32",
+            "--enable-p2p-check",
             action="store_true",
-            help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16."
-            "This only affects Triton attention kernels",
+            help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.",
         )
         parser.add_argument(
-            "--enable-p2p-check",
+            "--attention-reduce-in-fp32",
             action="store_true",
-            help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.",
+            help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16."
+            "This only affects Triton attention kernels",
         )
         parser.add_argument(
             "--efficient-weight-load",
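For completeness, a small standalone check (illustrative, standard-library argparse only) that the new store_true flag parses the way the add_argument call above implies, with the attribute name derived from the flag:

# Standalone argparse check mirroring the add_argument call in this diff.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--disable-flashinfer-sampling",
    action="store_true",
    help="Disable flashinfer sampling kernels.",
)

args = parser.parse_args(["--disable-flashinfer-sampling"])
assert args.disable_flashinfer_sampling is True      # dashes become underscores in the dest
assert parser.parse_args([]).disable_flashinfer_sampling is False   # defaults to False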