Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
f0f8a769
Unverified
Commit
f0f8a769
authored
Oct 18, 2024
by
Lianmin Zheng
Committed by
GitHub
Oct 18, 2024
Browse files
Simplify the nan detection and greedy check in sampler (#1709)
parent
2bcfba1b
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
24 additions
and
7 deletions
+24
-7
python/sglang/srt/layers/sampler.py
python/sglang/srt/layers/sampler.py
+6
-2
python/sglang/srt/managers/schedule_batch.py
python/sglang/srt/managers/schedule_batch.py
+1
-0
python/sglang/srt/model_executor/cuda_graph_runner.py
python/sglang/srt/model_executor/cuda_graph_runner.py
+4
-4
python/sglang/srt/model_executor/model_runner.py
python/sglang/srt/model_executor/model_runner.py
+1
-0
python/sglang/srt/sampling/sampling_batch_info.py
python/sglang/srt/sampling/sampling_batch_info.py
+5
-0
python/sglang/srt/server_args.py
python/sglang/srt/server_args.py
+7
-1
No files found.
python/sglang/srt/layers/sampler.py
View file @
f0f8a769
...
...
@@ -21,6 +21,10 @@ logger = logging.getLogger(__name__)
class
Sampler
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
use_nan_detectioin
=
not
global_server_args_dict
[
"disable_nan_detection"
]
def
forward
(
self
,
logits
:
Union
[
torch
.
Tensor
,
LogitsProcessorOutput
],
...
...
@@ -36,13 +40,13 @@ class Sampler(nn.Module):
logits
=
None
del
logits
if
torch
.
any
(
torch
.
isnan
(
probs
)):
if
self
.
use_nan_detectioin
and
torch
.
any
(
torch
.
isnan
(
probs
)):
logger
.
warning
(
"Detected errors during sampling! NaN in the probability."
)
probs
=
torch
.
where
(
torch
.
isnan
(
probs
),
torch
.
full_like
(
probs
,
1e-10
),
probs
)
if
sampling_info
.
top_ks
.
max
().
item
()
<=
1
:
if
sampling_info
.
is_all_greedy
:
# Use torch.argmax if all requests use greedy sampling
batch_next_token_ids
=
torch
.
argmax
(
probs
,
-
1
)
elif
global_server_args_dict
[
"sampling_backend"
]
==
"flashinfer"
:
...
...
python/sglang/srt/managers/schedule_batch.py
View file @
f0f8a769
...
...
@@ -53,6 +53,7 @@ global_server_args_dict = {
"triton_attention_reduce_in_fp32"
:
ServerArgs
.
triton_attention_reduce_in_fp32
,
"disable_mla"
:
ServerArgs
.
disable_mla
,
"torchao_config"
:
ServerArgs
.
torchao_config
,
"disable_nan_detection"
:
ServerArgs
.
disable_nan_detection
,
}
...
...
python/sglang/srt/model_executor/cuda_graph_runner.py
View file @
f0f8a769
...
...
@@ -245,10 +245,10 @@ class CudaGraphRunner:
self
.
out_cache_loc
.
zero_
()
# Common inputs
self
.
input_ids
[:
raw_bs
]
=
forward_batch
.
input_ids
self
.
req_pool_indices
[:
raw_bs
]
=
forward_batch
.
req_pool_indices
self
.
seq_lens
[:
raw_bs
]
=
forward_batch
.
seq_lens
self
.
out_cache_loc
[:
raw_bs
]
=
forward_batch
.
out_cache_loc
self
.
input_ids
[:
raw_bs
]
.
copy_
(
forward_batch
.
input_ids
)
self
.
req_pool_indices
[:
raw_bs
]
.
copy_
(
forward_batch
.
req_pool_indices
)
self
.
seq_lens
[:
raw_bs
]
.
copy_
(
forward_batch
.
seq_lens
)
self
.
out_cache_loc
[:
raw_bs
]
.
copy_
(
forward_batch
.
out_cache_loc
)
# Attention backend
self
.
model_runner
.
attn_backend
.
init_forward_metadata_replay_cuda_graph
(
...
...
python/sglang/srt/model_executor/model_runner.py
View file @
f0f8a769
...
...
@@ -137,6 +137,7 @@ class ModelRunner:
"disable_mla"
:
server_args
.
disable_mla
,
"torchao_config"
:
server_args
.
torchao_config
,
"disable_penalizer"
:
server_args
.
disable_penalizer
,
"disable_nan_detection"
:
server_args
.
disable_nan_detection
,
}
)
...
...
python/sglang/srt/sampling/sampling_batch_info.py
View file @
f0f8a769
...
...
@@ -20,6 +20,9 @@ class SamplingBatchInfo:
top_ks
:
torch
.
Tensor
min_ps
:
torch
.
Tensor
# All requests use greedy sampling
is_all_greedy
:
bool
# Dispatch in CUDA graph
need_min_p_sampling
:
bool
...
...
@@ -73,6 +76,7 @@ class SamplingBatchInfo:
top_ks
=
top_ks
,
min_ps
=
min_ps
,
need_min_p_sampling
=
any
(
r
.
sampling_params
.
min_p
>
0
for
r
in
reqs
),
is_all_greedy
=
top_ks
.
max
().
item
()
<=
1
,
vocab_size
=
vocab_size
,
device
=
batch
.
input_ids
.
device
,
)
...
...
@@ -204,6 +208,7 @@ class SamplingBatchInfo:
other_val
=
getattr
(
other
,
item
,
None
)
setattr
(
self
,
item
,
torch
.
concat
([
self_val
,
other_val
]))
self
.
is_all_greedy
=
self
.
is_all_greedy
and
other
.
is_all_greedy
self
.
logit_bias
=
SamplingBatchInfo
.
merge_bias_tensor
(
self
.
logit_bias
,
other
.
logit_bias
,
len
(
self
),
len
(
other
),
self
.
device
)
...
...
python/sglang/srt/server_args.py
View file @
f0f8a769
...
...
@@ -114,6 +114,7 @@ class ServerArgs:
disable_custom_all_reduce
:
bool
=
False
disable_mla
:
bool
=
False
disable_penalizer
:
bool
=
False
disable_nan_detection
:
bool
=
False
enable_overlap_schedule
:
bool
=
False
enable_mixed_chunk
:
bool
=
False
enable_torch_compile
:
bool
=
False
...
...
@@ -577,7 +578,12 @@ class ServerArgs:
parser
.
add_argument
(
"--disable-penalizer"
,
action
=
"store_true"
,
help
=
"Disable the logit penalizer (e.g., frequency and repetition penalty)."
,
help
=
"Disable the logit penalizers (e.g., frequency and repetition penalty) for better performance if they are not used in any requests."
,
)
parser
.
add_argument
(
"--disable-nan-detection"
,
action
=
"store_true"
,
help
=
"Disable the NaN detection for better performance."
,
)
parser
.
add_argument
(
"--enable-overlap-schedule"
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment