sglang / Commits / af4e7910

Commit af4e7910 (unverified), authored Jul 12, 2024 by Lianmin Zheng, committed via GitHub on Jul 12, 2024.
Parent: 519e20cf

Clean up the usage of flashinfer (#610)

Showing 5 changed files with 46 additions and 75 deletions (+46 / -75).
Changed files:

  python/sglang/srt/layers/radix_attention.py              +3   -12
  python/sglang/srt/layers/token_attention.py               +1   -7
  python/sglang/srt/managers/controller/infer_batch.py      +36  -48
  python/sglang/srt/managers/controller/model_runner.py     +0   -8
  python/sglang/srt/server_args.py                          +6   -0
python/sglang/srt/layers/radix_attention.py

@@ -31,21 +31,13 @@ class RadixAttention(nn.Module):
         self.layer_id = layer_id
 
         if not global_server_args_dict.get("disable_flashinfer", False):
-            self.prefill_forward = self.prefill_forward_flashinfer
-            self.extend_forward = self.prefill_forward_flashinfer
+            self.extend_forward = self.extend_forward_flashinfer
             self.decode_forward = self.decode_forward_flashinfer
-            # flashinfer now accepts float logit_cap argument
-            self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0
         else:
-            self.prefill_forward = self.prefill_forward_triton
             self.extend_forward = self.extend_forward_triton
             self.decode_forward = self.decode_forward_triton
-            self.logit_cap = logit_cap if logit_cap is not None else 0
 
-    def prefill_forward_triton(self, q, k, v, input_metadata: InputMetadata):
-        # In SGLang, we call both the typical "prefill" and "prefill with cache" as "extend".
-        # See the extend_forward_xxx functions.
-        raise NotImplementedError()
+        self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0
 
     def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata):
         o = torch.empty_like(q)

@@ -86,7 +78,6 @@ class RadixAttention(nn.Module):
             input_metadata.start_loc,
             input_metadata.seq_lens,
             input_metadata.max_seq_len,
-            input_metadata.other_kv_index,
             input_metadata.total_num_tokens,
             sm_scale=self.scaling,
             logit_cap=self.logit_cap,

@@ -94,7 +85,7 @@ class RadixAttention(nn.Module):
         return o
 
-    def prefill_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
+    def extend_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
         o1, s1 = input_metadata.flashinfer_prefill_wrapper_ragged.forward_return_lse(
             q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
             k.contiguous().view(-1, self.tp_k_head_num, self.head_dim),
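Note on the radix_attention.py change: the prefill-only Triton stub and the duplicated logit_cap handling are removed, and the flashinfer extend path is renamed extend_forward_flashinfer to match the Triton path. What remains is a plain dispatch-at-construction pattern. A minimal, self-contained sketch of that pattern (toy class with toy method bodies, not the sglang implementation):

    class ToyAttention:
        def __init__(self, use_flashinfer: bool, logit_cap=None):
            # Bind the backend once; callers only ever see extend_forward / decode_forward.
            if use_flashinfer:
                self.extend_forward = self.extend_forward_flashinfer
                self.decode_forward = self.decode_forward_flashinfer
            else:
                self.extend_forward = self.extend_forward_triton
                self.decode_forward = self.decode_forward_triton
            # One logit_cap rule shared by both backends, as in the new code.
            self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0

        def extend_forward_flashinfer(self, q, k, v):
            return "flashinfer-extend"

        def extend_forward_triton(self, q, k, v):
            return "triton-extend"

        def decode_forward_flashinfer(self, q, k, v):
            return "flashinfer-decode"

        def decode_forward_triton(self, q, k, v):
            return "triton-decode"

    attn = ToyAttention(use_flashinfer=False)
    print(attn.extend_forward(None, None, None))  # -> triton-extend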
python/sglang/srt/layers/token_attention.py

@@ -107,7 +107,6 @@ def _fwd_kernel_stage2(
     stride_obs,
     stride_oh,
     stride_req_to_token_b,
-    other_kv_index,  # To fix a NAN issue
     kv_group_num: tl.constexpr,
     BLOCK_DMODEL: tl.constexpr,
     BLOCK_N: tl.constexpr,

@@ -138,7 +137,7 @@ def _fwd_kernel_stage2(
             + cur_batch_req_idx * stride_req_to_token_b
             + (start_n + offs_n),
             mask=(start_n + offs_n) < cur_batch_seq_len,
-            other=other_kv_index,
+            other=0,
         )
         qk = tl.load(

@@ -250,7 +249,6 @@ def _token_softmax_reducev_fwd(
     b_req_idx,
     b_start_loc,
     b_seq_len,
-    other_kv_index,
 ):
     BLOCK = 64
     batch, head = b_seq_len.shape[0], logics.shape[0]

@@ -277,7 +275,6 @@ def _token_softmax_reducev_fwd(
         o.stride(0),
         o.stride(1),
         req_to_tokens.stride(0),
-        other_kv_index,
     )
     return

@@ -295,7 +292,6 @@ def _token_softmax_reducev_fwd(
         o.stride(0),
         o.stride(1),
         req_to_tokens.stride(0),
-        other_kv_index,
         kv_group_num=kv_group_num,
         BLOCK_DMODEL=v_buffer.shape[-1],
         BLOCK_N=BLOCK,

@@ -315,7 +311,6 @@ def token_attention_fwd(
     b_start_loc,
     b_seq_len,
     max_len_in_batch,
-    other_kv_index,
     total_num_tokens,
     sm_scale=None,
     logit_cap=-1,

@@ -347,5 +342,4 @@ def token_attention_fwd(
         b_req_idx,
         b_start_loc,
         b_seq_len,
-        other_kv_index,
     )
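Note on the token_attention.py change: `other=` in `tl.load` is the fill value returned for lanes whose mask is False. The old kernel filled masked lanes of the req_to_token gather with other_kv_index, a known-valid KV slot (the removed "# To fix a NAN issue" comment ties it to a NaN fix); the new kernel fills with index 0, so the extra argument can be dropped from every call site. A minimal sketch of the masked-load behaviour, assuming a CUDA device with triton installed (toy kernel, not the sglang one):

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def gather_kernel(idx_ptr, src_ptr, out_ptr, n, BLOCK: tl.constexpr):
        offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
        mask = offs < n
        # Masked lanes do not read memory; they take the `other` value (index 0 here),
        # which keeps the second, unmasked gather below in bounds.
        idx = tl.load(idx_ptr + offs, mask=mask, other=0)
        val = tl.load(src_ptr + idx)
        tl.store(out_ptr + offs, val, mask=mask)

    src = torch.arange(100, device="cuda", dtype=torch.float32)
    idx = torch.tensor([5, 7, 9], device="cuda", dtype=torch.int32)
    out = torch.empty(3, device="cuda", dtype=torch.float32)
    gather_kernel[(1,)](idx, src, out, 3, BLOCK=8)
    print(out.tolist())  # [5.0, 7.0, 9.0]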
python/sglang/srt/managers/controller/infer_batch.py

@@ -729,7 +729,6 @@ class InputMetadata:
     out_cache_cont_start: torch.Tensor = None
     out_cache_cont_end: torch.Tensor = None
-    other_kv_index: torch.Tensor = None
     return_logprob: bool = False
     top_logprobs_nums: List[int] = None

@@ -743,24 +742,19 @@ class InputMetadata:
     flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
 
     def init_flashinfer_args(self, num_qo_heads, num_kv_heads, head_dim):
-        if (
-            self.forward_mode == ForwardMode.EXTEND
-        ):
+        if self.forward_mode == ForwardMode.DECODE:
+            paged_kernel_lens = self.seq_lens
+        else:
             paged_kernel_lens = self.prefix_lens
             self.no_prefix = torch.all(self.prefix_lens == 0)
-        else:
-            paged_kernel_lens = self.seq_lens
 
-        self.kv_indptr = torch.zeros(
+        kv_indptr = torch.zeros(
             (self.batch_size + 1,), dtype=torch.int32, device="cuda"
         )
-        self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-        self.kv_last_page_len = torch.ones(
-            (self.batch_size,), dtype=torch.int32, device="cuda"
-        )
+        kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
         req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
         paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
-        self.kv_indices = torch.cat(
+        kv_indices = torch.cat(
             [
                 self.req_to_token_pool.req_to_token[
                     req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
                 ]

@@ -769,18 +763,34 @@ class InputMetadata:
             ],
             dim=0,
         ).contiguous()
+        kv_last_page_len = torch.ones(
+            (self.batch_size,), dtype=torch.int32, device="cuda"
+        )
 
-        if self.forward_mode == ForwardMode.EXTEND:
+        if self.forward_mode == ForwardMode.DECODE:
+            self.flashinfer_decode_wrapper.end_forward()
+            self.flashinfer_decode_wrapper.begin_forward(
+                kv_indptr,
+                kv_indices,
+                kv_last_page_len,
+                num_qo_heads,
+                num_kv_heads,
+                head_dim,
+                1,
+                pos_encoding_mode="NONE",
+                data_type=self.token_to_kv_pool.kv_data[0].dtype,
+            )
+        else:
             # extend part
-            self.qo_indptr = torch.zeros(
+            qo_indptr = torch.zeros(
                 (self.batch_size + 1,), dtype=torch.int32, device="cuda"
             )
-            self.qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
+            qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
 
             self.flashinfer_prefill_wrapper_ragged.end_forward()
             self.flashinfer_prefill_wrapper_ragged.begin_forward(
-                self.qo_indptr,
-                self.qo_indptr.clone(),
+                qo_indptr,
+                qo_indptr,
                 num_qo_heads,
                 num_kv_heads,
                 head_dim,

@@ -789,28 +799,15 @@ class InputMetadata:
             # cached part
             self.flashinfer_prefill_wrapper_paged.end_forward()
             self.flashinfer_prefill_wrapper_paged.begin_forward(
-                self.qo_indptr,
-                self.kv_indptr,
-                self.kv_indices,
-                self.kv_last_page_len,
+                qo_indptr,
+                kv_indptr,
+                kv_indices,
+                kv_last_page_len,
                 num_qo_heads,
                 num_kv_heads,
                 head_dim,
                 1,
             )
-        else:
-            self.flashinfer_decode_wrapper.end_forward()
-            self.flashinfer_decode_wrapper.begin_forward(
-                self.kv_indptr,
-                self.kv_indices,
-                self.kv_last_page_len,
-                num_qo_heads,
-                num_kv_heads,
-                head_dim,
-                1,
-                pos_encoding_mode="NONE",
-                data_type=self.token_to_kv_pool.kv_data[0].dtype,
-            )
 
     def init_extend_args(self):
         self.extend_seq_lens = self.seq_lens - self.prefix_lens

@@ -822,7 +819,6 @@ class InputMetadata:
     def create(
         cls,
         model_runner,
-        tp_size,
         forward_mode,
         req_pool_indices,
         seq_lens,

@@ -833,9 +829,6 @@ class InputMetadata:
         out_cache_cont_end=None,
         top_logprobs_nums=None,
         return_logprob=False,
-        flashinfer_prefill_wrapper_ragged=None,
-        flashinfer_prefill_wrapper_paged=None,
-        flashinfer_decode_wrapper=None,
     ):
         batch_size = len(req_pool_indices)
         start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")

@@ -845,9 +838,6 @@ class InputMetadata:
         if forward_mode == ForwardMode.DECODE:
             positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
-            other_kv_index = model_runner.req_to_token_pool.req_to_token[
-                req_pool_indices[0], seq_lens[0] - 1
-            ].item()
         else:
             seq_lens_cpu = seq_lens.cpu().numpy()
             prefix_lens_cpu = prefix_lens.cpu().numpy()

@@ -865,7 +855,6 @@ class InputMetadata:
                 ),
                 device="cuda",
             )
-            other_kv_index = None
 
         ret = cls(
             forward_mode=forward_mode,

@@ -882,12 +871,11 @@ class InputMetadata:
             out_cache_loc=out_cache_loc,
             out_cache_cont_start=out_cache_cont_start,
             out_cache_cont_end=out_cache_cont_end,
-            other_kv_index=other_kv_index,
             return_logprob=return_logprob,
             top_logprobs_nums=top_logprobs_nums,
-            flashinfer_prefill_wrapper_ragged=flashinfer_prefill_wrapper_ragged,
-            flashinfer_prefill_wrapper_paged=flashinfer_prefill_wrapper_paged,
-            flashinfer_decode_wrapper=flashinfer_decode_wrapper,
+            flashinfer_prefill_wrapper_ragged=model_runner.flashinfer_prefill_wrapper_ragged,
+            flashinfer_prefill_wrapper_paged=model_runner.flashinfer_prefill_wrapper_paged,
+            flashinfer_decode_wrapper=model_runner.flashinfer_decode_wrapper,
         )
 
         if forward_mode == ForwardMode.EXTEND:

@@ -895,8 +883,8 @@ class InputMetadata:
         if not global_server_args_dict.get("disable_flashinfer", False):
             ret.init_flashinfer_args(
-                model_runner.model_config.num_attention_heads // tp_size,
-                model_runner.model_config.get_num_kv_heads(tp_size),
+                model_runner.model_config.num_attention_heads // model_runner.tp_size,
+                model_runner.model_config.get_num_kv_heads(model_runner.tp_size),
                 model_runner.model_config.head_dim,
             )
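Note on the infer_batch.py change: kv_indptr, kv_indices, and kv_last_page_len become locals instead of InputMetadata attributes, the DECODE and EXTEND branches are reordered, and the flashinfer wrappers are read from model_runner rather than being threaded through create(). The indptr layout itself is unchanged: it is an exclusive prefix sum of the per-request KV lengths, so request i owns kv_indices[kv_indptr[i]:kv_indptr[i+1]]. A small CPU-only illustration with made-up lengths (names follow the diff; the real code builds these on CUDA and feeds them to flashinfer's begin_forward):

    import torch

    batch_size = 3
    paged_kernel_lens = torch.tensor([3, 5, 2], dtype=torch.int32)  # per-request KV lengths

    kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32)
    kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
    print(kv_indptr.tolist())  # [0, 3, 8, 10]: request i owns kv_indices[kv_indptr[i]:kv_indptr[i+1]]

    # Page size is 1 (the literal `1` passed to begin_forward), so the last page
    # of every request always holds exactly one token.
    kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32)
    print(kv_last_page_len.tolist())  # [1, 1, 1]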
python/sglang/srt/managers/controller/model_runner.py

@@ -221,7 +221,6 @@ class ModelRunner:
         input_metadata = InputMetadata.create(
             self,
             forward_mode=ForwardMode.EXTEND,
-            tp_size=self.tp_size,
             req_pool_indices=batch.req_pool_indices,
             seq_lens=batch.seq_lens,
             prefix_lens=batch.prefix_lens,

@@ -229,9 +228,6 @@ class ModelRunner:
             out_cache_loc=batch.out_cache_loc,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
-            flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
-            flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
-            flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
         )
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata

@@ -242,7 +238,6 @@ class ModelRunner:
         input_metadata = InputMetadata.create(
             self,
             forward_mode=ForwardMode.DECODE,
-            tp_size=self.tp_size,
             req_pool_indices=batch.req_pool_indices,
             seq_lens=batch.seq_lens,
             prefix_lens=batch.prefix_lens,

@@ -252,9 +247,6 @@ class ModelRunner:
             out_cache_cont_end=batch.out_cache_cont_end,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
-            flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
-            flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
-            flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
         )
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
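Note on the model_runner.py change: the call sites stop forwarding tp_size and the three flashinfer wrappers because InputMetadata.create already receives the runner (self) as its first argument and can read those fields directly. A toy sketch of the same refactor (hypothetical names, not the sglang API):

    class ToyRunner:
        def __init__(self):
            self.tp_size = 2
            self.flashinfer_decode_wrapper = "decode-wrapper"  # stand-in object

    # Before: every field threaded through as a keyword argument.
    def create_metadata_old(runner, tp_size, flashinfer_decode_wrapper):
        return {"tp_size": tp_size, "decode_wrapper": flashinfer_decode_wrapper}

    # After: the factory reads what it needs off the runner it already receives.
    def create_metadata_new(runner):
        return {"tp_size": runner.tp_size, "decode_wrapper": runner.flashinfer_decode_wrapper}

    runner = ToyRunner()
    assert create_metadata_old(
        runner, runner.tp_size, runner.flashinfer_decode_wrapper
    ) == create_metadata_new(runner)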
python/sglang/srt/server_args.py

@@ -53,6 +53,7 @@ class ServerArgs:
     disable_flashinfer: bool = False
     disable_radix_cache: bool = False
     disable_regex_jump_forward: bool = False
+    disable_cuda_graph: bool = False
     disable_disk_cache: bool = False
     attention_reduce_in_fp32: bool = False
     enable_p2p_check: bool = False

@@ -294,6 +295,11 @@ class ServerArgs:
             action="store_true",
             help="Disable regex jump-forward",
         )
+        parser.add_argument(
+            "--disable-cuda-graph",
+            action="store_true",
+            help="Disable cuda graph.",
+        )
         parser.add_argument(
             "--disable-disk-cache",
             action="store_true",
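Note on the server_args.py change: it only registers the new disable_cuda_graph field and its --disable-cuda-graph flag next to the existing disable_* switches; no other file in this diff consumes the flag. A minimal sketch of how such a dataclass field and argparse flag line up (toy parser, not the full ServerArgs.add_cli_args):

    import argparse
    from dataclasses import dataclass

    @dataclass
    class ToyServerArgs:
        disable_flashinfer: bool = False
        disable_cuda_graph: bool = False

    parser = argparse.ArgumentParser()
    parser.add_argument("--disable-flashinfer", action="store_true")
    parser.add_argument("--disable-cuda-graph", action="store_true")

    args = parser.parse_args(["--disable-cuda-graph"])
    print(ToyServerArgs(**vars(args)))
    # -> ToyServerArgs(disable_flashinfer=False, disable_cuda_graph=True)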