Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
zhaoyu6
sglang
Commits
f424e76d
"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "f364daf0d2df6f96855c0ac3f43ac3a1d5f2b0d6"
Unverified
Commit
f424e76d
authored
Jul 20, 2024
by
Liangsheng Yin
Committed by
GitHub
Jul 20, 2024
Browse files
Fix illegal tokens during sampling (#676)
parent
490a1f39
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
18 additions
and
15 deletions
+18
-15
docs/sampling_params.md
docs/sampling_params.md
+1
-1
python/sglang/srt/managers/controller/infer_batch.py
python/sglang/srt/managers/controller/infer_batch.py
+14
-10
python/sglang/srt/managers/controller/model_runner.py
python/sglang/srt/managers/controller/model_runner.py
+2
-3
python/sglang/srt/managers/io_struct.py
python/sglang/srt/managers/io_struct.py
+1
-1
No files found.
docs/sampling_params.md
View file @
f424e76d
...
@@ -7,7 +7,7 @@ The `/generate` endpoint accepts the following arguments in the JSON format.
...
@@ -7,7 +7,7 @@ The `/generate` endpoint accepts the following arguments in the JSON format.
@
dataclass
@
dataclass
class
GenerateReqInput
:
class
GenerateReqInput
:
# The input prompt. It can be a single prompt or a batch of prompts.
# The input prompt. It can be a single prompt or a batch of prompts.
text
:
Union
[
List
[
str
],
str
]
text
:
Optional
[
Union
[
List
[
str
],
str
]
]
=
None
# The token ids for text; one can either specify text or input_ids.
# The token ids for text; one can either specify text or input_ids.
input_ids
:
Optional
[
Union
[
List
[
List
[
int
]],
List
[
int
]]]
=
None
input_ids
:
Optional
[
Union
[
List
[
List
[
int
]],
List
[
int
]]]
=
None
# The image input. It can be a file name, a url, or base64 encoded string.
# The image input. It can be a file name, a url, or base64 encoded string.
...
...
python/sglang/srt/managers/controller/infer_batch.py
View file @
f424e76d
...
@@ -665,16 +665,20 @@ class Batch:
...
@@ -665,16 +665,20 @@ class Batch:
# TODO(lmzheng): apply penalty
# TODO(lmzheng): apply penalty
probs
=
torch
.
softmax
(
logits
,
dim
=-
1
)
probs
=
torch
.
softmax
(
logits
,
dim
=-
1
)
try
:
max_top_k_round
,
batch_size
=
32
,
probs
.
shape
[
0
]
max_top_k_round
,
batch_size
=
32
,
probs
.
shape
[
0
]
uniform_samples
=
torch
.
rand
(
uniform_samples
=
torch
.
rand
((
max_top_k_round
,
batch_size
),
device
=
probs
.
device
)
(
max_top_k_round
,
batch_size
),
device
=
probs
.
device
batch_next_token_ids
,
_
=
top_k_top_p_sampling_from_probs
(
)
probs
,
uniform_samples
,
self
.
top_ks
,
self
.
top_ps
batch_next_token_ids
,
_
=
top_k_top_p_sampling_from_probs
(
)
probs
,
uniform_samples
,
self
.
top_ks
,
self
.
top_ps
)
# FIXME: this is a temporary fix for the illegal token ids
except
RuntimeError
as
e
:
illegal_mask
=
torch
.
logical_or
(
warnings
.
warn
(
f
"Ignore errors in sampling:
{
e
}
"
)
batch_next_token_ids
<
0
,
batch_next_token_ids
>=
probs
.
shape
[
-
1
]
)
if
torch
.
any
(
illegal_mask
):
warnings
.
warn
(
"Illegal sampled token ids"
)
probs
=
probs
.
masked_fill
(
torch
.
isnan
(
probs
),
0.0
)
batch_next_token_ids
=
torch
.
argmax
(
probs
,
dim
=-
1
)
batch_next_token_ids
=
torch
.
argmax
(
probs
,
dim
=-
1
)
if
has_regex
:
if
has_regex
:
...
...
python/sglang/srt/managers/controller/model_runner.py
View file @
f424e76d
...
@@ -246,12 +246,11 @@ class ModelRunner:
...
@@ -246,12 +246,11 @@ class ModelRunner:
self
.
cuda_graph_runner
=
CudaGraphRunner
(
self
.
cuda_graph_runner
=
CudaGraphRunner
(
self
,
max_batch_size_to_capture
=
max
(
batch_size_list
)
self
,
max_batch_size_to_capture
=
max
(
batch_size_list
)
)
)
logger
.
info
(
f
"Capture for batch sizes
{
batch_size_list
}
"
)
try
:
try
:
self
.
cuda_graph_runner
.
capture
(
batch_size_list
)
self
.
cuda_graph_runner
.
capture
(
batch_size_list
)
except
:
except
RuntimeError
as
e
:
raise
Exception
(
raise
Exception
(
f
"Capture cuda graph failed. Possible solutions:
\n
"
f
"Capture cuda graph failed
{
e
}
. Possible solutions:
\n
"
f
"1. disable cuda graph by --disable-cuda-graph
\n
"
f
"1. disable cuda graph by --disable-cuda-graph
\n
"
f
"2. set --mem-fraction-static to a smaller value
\n
"
f
"2. set --mem-fraction-static to a smaller value
\n
"
f
"Open an issue on GitHub with reproducible scripts if you need help.
\n
"
f
"Open an issue on GitHub with reproducible scripts if you need help.
\n
"
...
...
python/sglang/srt/managers/io_struct.py
View file @
f424e76d
...
@@ -14,7 +14,7 @@ from sglang.srt.sampling_params import SamplingParams
...
@@ -14,7 +14,7 @@ from sglang.srt.sampling_params import SamplingParams
@
dataclass
@
dataclass
class
GenerateReqInput
:
class
GenerateReqInput
:
# The input prompt. It can be a single prompt or a batch of prompts.
# The input prompt. It can be a single prompt or a batch of prompts.
text
:
Union
[
List
[
str
],
str
]
text
:
Optional
[
Union
[
List
[
str
],
str
]
]
=
None
# The token ids for text; one can either specify text or input_ids.
# The token ids for text; one can either specify text or input_ids.
input_ids
:
Optional
[
Union
[
List
[
List
[
int
]],
List
[
int
]]]
=
None
input_ids
:
Optional
[
Union
[
List
[
List
[
int
]],
List
[
int
]]]
=
None
# The image input. It can be a file name, a url, or base64 encoded string.
# The image input. It can be a file name, a url, or base64 encoded string.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment