Unverified commit 3de2f30a, authored Jul 17, 2024 by Liangsheng Yin; committed by GitHub on Jul 17, 2024.
Flashinfer sample kernel (#617)
Parent: 4efcc59d
Showing 4 changed files with 17 additions and 30 deletions (+17 −30):
python/sglang/bench_latency.py (+2 −2)
python/sglang/srt/managers/controller/infer_batch.py (+12 −25)
python/sglang/srt/managers/controller/tp_worker.py (+2 −2)
python/sglang/srt/server.py (+1 −1)
python/sglang/bench_latency.py

@@ -156,14 +156,14 @@ def extend(reqs, model_runner):
     )
     batch.prepare_for_extend(model_runner.model_config.vocab_size, None)
     output = model_runner.forward(batch, ForwardMode.EXTEND)
-    next_token_ids, _ = batch.sample(output.next_token_logits)
+    next_token_ids = batch.sample(output.next_token_logits)
     return next_token_ids, output.next_token_logits, batch


 def decode(input_token_ids, batch, model_runner):
     batch.prepare_for_decode(input_token_ids.cpu().numpy())
     output = model_runner.forward(batch, ForwardMode.DECODE)
-    next_token_ids, _ = batch.sample(output.next_token_logits)
+    next_token_ids = batch.sample(output.next_token_logits)
     return next_token_ids, output.next_token_logits
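Note: the one-line change in each caller reflects the new return contract of Batch.sample introduced in this commit: it now returns only the sampled token ids rather than an (ids, probs) pair. A minimal, hypothetical stand-in in plain PyTorch (not the real sglang Batch) showing the before/after contract at a call site:

```python
import torch

# Hypothetical stand-in illustrating the call-site change in this commit:
# sample() used to return (token_ids, token_probs); the flashinfer-backed
# path returns only token_ids, so callers drop the tuple unpacking.
def sample_old(logits: torch.Tensor):
    probs = torch.softmax(logits, dim=-1)
    idx = torch.multinomial(probs, num_samples=1)
    return idx.view(-1), torch.gather(probs, 1, idx).view(-1)

def sample_new(logits: torch.Tensor):
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1).view(-1)

logits = torch.randn(4, 128)
old_ids, old_probs = sample_old(logits)  # old contract: (ids, probs)
new_ids = sample_new(logits)             # new contract: ids only
```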
python/sglang/srt/managers/controller/infer_batch.py

@@ -7,6 +7,7 @@ from typing import List, Union
 import numpy as np
 import torch
+from flashinfer.sampling import top_k_top_p_sampling_from_probs

 from sglang.srt.constrained import RegexGuide
 from sglang.srt.constrained.jump_forward import JumpForwardMap

@@ -398,10 +399,10 @@ class Batch:
         ).view(-1, 1)
         self.top_ps = torch.tensor(
             [r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
-        ).view(-1, 1)
+        )
         self.top_ks = torch.tensor(
             [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
-        ).view(-1, 1)
+        )
         self.frequency_penalties = torch.tensor(
             [r.sampling_params.frequency_penalty for r in reqs],
             dtype=torch.float,

@@ -659,20 +660,17 @@ class Batch:
         # TODO(lmzheng): apply penalty
         probs = torch.softmax(logits, dim=-1)
-        probs_sort, probs_idx = _top_p_top_k(probs, self.top_ps, self.top_ks)
         try:
-            sampled_index = torch.multinomial(probs_sort, num_samples=1)
+            max_top_k_round, batch_size = 32, probs.shape[0]
+            uniform_samples = torch.rand(
+                (max_top_k_round, batch_size), device=probs.device
+            )
+            batch_next_token_ids, _ = top_k_top_p_sampling_from_probs(
+                probs, uniform_samples, self.top_ks, self.top_ps
+            )
         except RuntimeError as e:
             warnings.warn(f"Ignore errors in sampling: {e}")
-            sampled_index = torch.ones(
-                probs_sort.shape[:-1] + (1,), dtype=torch.int64, device=probs.device
-            )
-        batch_next_token_ids = torch.gather(
-            probs_idx, dim=1, index=sampled_index
-        ).view(-1)
-        batch_next_token_probs = torch.gather(
-            probs_sort, dim=1, index=sampled_index
-        ).view(-1)
+            batch_next_token_ids = torch.argmax(probs, dim=-1)

         if has_regex:
             batch_next_token_ids_cpu = batch_next_token_ids.cpu().numpy()

@@ -682,18 +680,7 @@ class Batch:
                         req.regex_fsm_state,
                         batch_next_token_ids_cpu[i],
                     )
-        return batch_next_token_ids, batch_next_token_probs
-
-
-def _top_p_top_k(probs: torch.Tensor, top_ps: torch.Tensor, top_ks: torch.Tensor):
-    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
-    probs_sum = torch.cumsum(probs_sort, dim=-1)
-    probs_sort[(probs_sum - probs_sort) > top_ps] = 0.0
-    probs_sort[
-        torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1) >= top_ks
-    ] = 0.0
-    probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
-    return probs_sort, probs_idx
+        return batch_next_token_ids


 @dataclass
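Note: the new path calls flashinfer's fused top-k/top-p rejection-sampling kernel, which consumes pre-drawn uniform noise (one row per rejection round) and flat per-request top_k/top_p tensors, which is consistent with the .view(-1, 1) reshapes being dropped above. A standalone sketch of the same call pattern, assuming a CUDA device and a flashinfer build around 0.1.0 that exposes this entry point; the batch size, vocabulary size, and top_k/top_p values here are illustrative only:

```python
import torch
from flashinfer.sampling import top_k_top_p_sampling_from_probs

# Standalone sketch of the sampling path introduced in this commit
# (assumes CUDA and flashinfer ~0.1.0; shapes/values are illustrative).
batch_size, vocab_size = 8, 32000
logits = torch.randn(batch_size, vocab_size, device="cuda")
probs = torch.softmax(logits, dim=-1)

# Per-request sampling parameters as flat 1-D tensors, one entry per request.
top_ks = torch.full((batch_size,), 40, dtype=torch.int, device="cuda")
top_ps = torch.full((batch_size,), 0.9, dtype=torch.float, device="cuda")

# One row of uniform noise per rejection-sampling round (32 rounds, as above).
max_top_k_round = 32
uniform_samples = torch.rand((max_top_k_round, batch_size), device=probs.device)

try:
    next_token_ids, _ = top_k_top_p_sampling_from_probs(
        probs, uniform_samples, top_ks, top_ps
    )
except RuntimeError as e:
    # Same fallback as the except branch in the diff: greedy argmax.
    next_token_ids = torch.argmax(probs, dim=-1)
```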
python/sglang/srt/managers/controller/tp_worker.py

@@ -451,7 +451,7 @@ class ModelTpServer:
             # Forward and sample the next tokens
             if batch.extend_num_tokens != 0:
                 output = self.model_runner.forward(batch, ForwardMode.EXTEND)
-                next_token_ids, _ = batch.sample(output.next_token_logits)
+                next_token_ids = batch.sample(output.next_token_logits)

                 # Move logprobs to cpu
                 if output.next_token_logprobs is not None:

@@ -574,7 +574,7 @@ class ModelTpServer:
         # Forward and sample the next tokens
         output = self.model_runner.forward(batch, ForwardMode.DECODE)
-        next_token_ids, _ = batch.sample(output.next_token_logits)
+        next_token_ids = batch.sample(output.next_token_logits)

         # Move logprobs to cpu
         if output.next_token_logprobs is not None:
python/sglang/srt/server.py

@@ -154,7 +154,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     if not server_args.disable_flashinfer:
         assert_pkg_version(
             "flashinfer",
-            "0.0.8",
+            "0.1.0",
             "Please uninstall the old version and "
             "reinstall the latest version by following the instructions "
             "at https://docs.flashinfer.ai/installation.html.",
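Note: the launch-time guard above raises the required flashinfer version from 0.0.8 to 0.1.0, presumably because top_k_top_p_sampling_from_probs needs the newer release. As a rough illustration of what a minimum-version guard like this does, here is a hypothetical stand-in built on the standard library; it is not sglang's actual assert_pkg_version, whose signature and behavior may differ:

```python
from importlib.metadata import PackageNotFoundError, version

# Hypothetical minimum-version guard, sketching the check performed above.
def check_min_version(pkg: str, min_version: str, hint: str) -> None:
    try:
        installed = version(pkg)
    except PackageNotFoundError:
        raise RuntimeError(f"{pkg} is not installed. {hint}")

    def as_tuple(v: str):
        # Naive numeric parse; a production check would use packaging.version.
        return tuple(int(p) for p in v.split("+")[0].split(".")[:3])

    if as_tuple(installed) < as_tuple(min_version):
        raise RuntimeError(
            f"{pkg}=={installed} is older than required {min_version}. {hint}"
        )

check_min_version(
    "flashinfer",
    "0.1.0",
    "Please reinstall by following https://docs.flashinfer.ai/installation.html.",
)
```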