Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
f424e76d
Unverified
Commit
f424e76d
authored
Jul 20, 2024
by
Liangsheng Yin
Committed by
GitHub
Jul 20, 2024
Browse files
Fix illegal tokens during sampling (#676)
parent
490a1f39
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
18 additions
and
15 deletions
+18
-15
docs/sampling_params.md
docs/sampling_params.md
+1
-1
python/sglang/srt/managers/controller/infer_batch.py
python/sglang/srt/managers/controller/infer_batch.py
+14
-10
python/sglang/srt/managers/controller/model_runner.py
python/sglang/srt/managers/controller/model_runner.py
+2
-3
python/sglang/srt/managers/io_struct.py
python/sglang/srt/managers/io_struct.py
+1
-1
No files found.
docs/sampling_params.md
View file @
f424e76d
...
...
@@ -7,7 +7,7 @@ The `/generate` endpoint accepts the following arguments in the JSON format.
@
dataclass
class
GenerateReqInput
:
# The input prompt. It can be a single prompt or a batch of prompts.
text
:
Union
[
List
[
str
],
str
]
text
:
Optional
[
Union
[
List
[
str
],
str
]
]
=
None
# The token ids for text; one can either specify text or input_ids.
input_ids
:
Optional
[
Union
[
List
[
List
[
int
]],
List
[
int
]]]
=
None
# The image input. It can be a file name, a url, or base64 encoded string.
...
...
python/sglang/srt/managers/controller/infer_batch.py
View file @
f424e76d
...
...
@@ -665,16 +665,20 @@ class Batch:
# TODO(lmzheng): apply penalty
probs
=
torch
.
softmax
(
logits
,
dim
=-
1
)
try
:
max_top_k_round
,
batch_size
=
32
,
probs
.
shape
[
0
]
uniform_samples
=
torch
.
rand
(
(
max_top_k_round
,
batch_size
),
device
=
probs
.
device
)
uniform_samples
=
torch
.
rand
((
max_top_k_round
,
batch_size
),
device
=
probs
.
device
)
batch_next_token_ids
,
_
=
top_k_top_p_sampling_from_probs
(
probs
,
uniform_samples
,
self
.
top_ks
,
self
.
top_ps
)
except
RuntimeError
as
e
:
warnings
.
warn
(
f
"Ignore errors in sampling:
{
e
}
"
)
# FIXME: this is a temporary fix for the illegal token ids
illegal_mask
=
torch
.
logical_or
(
batch_next_token_ids
<
0
,
batch_next_token_ids
>=
probs
.
shape
[
-
1
]
)
if
torch
.
any
(
illegal_mask
):
warnings
.
warn
(
"Illegal sampled token ids"
)
probs
=
probs
.
masked_fill
(
torch
.
isnan
(
probs
),
0.0
)
batch_next_token_ids
=
torch
.
argmax
(
probs
,
dim
=-
1
)
if
has_regex
:
...
...
python/sglang/srt/managers/controller/model_runner.py
View file @
f424e76d
...
...
@@ -246,12 +246,11 @@ class ModelRunner:
self
.
cuda_graph_runner
=
CudaGraphRunner
(
self
,
max_batch_size_to_capture
=
max
(
batch_size_list
)
)
logger
.
info
(
f
"Capture for batch sizes
{
batch_size_list
}
"
)
try
:
self
.
cuda_graph_runner
.
capture
(
batch_size_list
)
except
:
except
RuntimeError
as
e
:
raise
Exception
(
f
"Capture cuda graph failed. Possible solutions:
\n
"
f
"Capture cuda graph failed
{
e
}
. Possible solutions:
\n
"
f
"1. disable cuda graph by --disable-cuda-graph
\n
"
f
"2. set --mem-fraction-static to a smaller value
\n
"
f
"Open an issue on GitHub with reproducible scripts if you need help.
\n
"
...
...
python/sglang/srt/managers/io_struct.py
View file @
f424e76d
...
...
@@ -14,7 +14,7 @@ from sglang.srt.sampling_params import SamplingParams
@
dataclass
class
GenerateReqInput
:
# The input prompt. It can be a single prompt or a batch of prompts.
text
:
Union
[
List
[
str
],
str
]
text
:
Optional
[
Union
[
List
[
str
],
str
]
]
=
None
# The token ids for text; one can either specify text or input_ids.
input_ids
:
Optional
[
Union
[
List
[
List
[
int
]],
List
[
int
]]]
=
None
# The image input. It can be a file name, a url, or base64 encoded string.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment