change / sglang / Commits / 3de2f30a

Commit 3de2f30a (Unverified), authored Jul 17, 2024 by Liangsheng Yin, committed via GitHub on Jul 17, 2024

Flashinfer sample kernel (#617)

Parent: 4efcc59d

Showing 4 changed files with 17 additions and 30 deletions (+17, -30)
python/sglang/bench_latency.py (+2, -2)
python/sglang/srt/managers/controller/infer_batch.py (+12, -25)
python/sglang/srt/managers/controller/tp_worker.py (+2, -2)
python/sglang/srt/server.py (+1, -1)
python/sglang/bench_latency.py

```diff
@@ -156,14 +156,14 @@ def extend(reqs, model_runner):
     )
     batch.prepare_for_extend(model_runner.model_config.vocab_size, None)
     output = model_runner.forward(batch, ForwardMode.EXTEND)
-    next_token_ids, _ = batch.sample(output.next_token_logits)
+    next_token_ids = batch.sample(output.next_token_logits)
     return next_token_ids, output.next_token_logits, batch


 def decode(input_token_ids, batch, model_runner):
     batch.prepare_for_decode(input_token_ids.cpu().numpy())
     output = model_runner.forward(batch, ForwardMode.DECODE)
-    next_token_ids, _ = batch.sample(output.next_token_logits)
+    next_token_ids = batch.sample(output.next_token_logits)
     return next_token_ids, output.next_token_logits
```
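With the flashinfer kernel in place, `Batch.sample` returns a single tensor of next-token ids instead of an `(ids, probs)` tuple, so every call site drops the `, _` unpacking. A hedged sketch of how a caller drives these two helpers after the change (`reqs`, `batch`, `model_runner`, and the step count are assumed to come from the surrounding benchmark code):

```python
# Prefill once, then decode step by step; sample() now yields ids only.
next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
for _ in range(8):  # illustrative number of decode steps
    next_token_ids, next_token_logits = decode(next_token_ids, batch, model_runner)
```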
python/sglang/srt/managers/controller/infer_batch.py

```diff
@@ -7,6 +7,7 @@ from typing import List, Union

 import numpy as np
 import torch
+from flashinfer.sampling import top_k_top_p_sampling_from_probs

 from sglang.srt.constrained import RegexGuide
 from sglang.srt.constrained.jump_forward import JumpForwardMap
```
@@ -398,10 +399,10 @@ class Batch:
...
@@ -398,10 +399,10 @@ class Batch:
).
view
(
-
1
,
1
)
).
view
(
-
1
,
1
)
self
.
top_ps
=
torch
.
tensor
(
self
.
top_ps
=
torch
.
tensor
(
[
r
.
sampling_params
.
top_p
for
r
in
reqs
],
dtype
=
torch
.
float
,
device
=
device
[
r
.
sampling_params
.
top_p
for
r
in
reqs
],
dtype
=
torch
.
float
,
device
=
device
)
.
view
(
-
1
,
1
)
)
self
.
top_ks
=
torch
.
tensor
(
self
.
top_ks
=
torch
.
tensor
(
[
r
.
sampling_params
.
top_k
for
r
in
reqs
],
dtype
=
torch
.
int
,
device
=
device
[
r
.
sampling_params
.
top_k
for
r
in
reqs
],
dtype
=
torch
.
int
,
device
=
device
)
.
view
(
-
1
,
1
)
)
self
.
frequency_penalties
=
torch
.
tensor
(
self
.
frequency_penalties
=
torch
.
tensor
(
[
r
.
sampling_params
.
frequency_penalty
for
r
in
reqs
],
[
r
.
sampling_params
.
frequency_penalty
for
r
in
reqs
],
dtype
=
torch
.
float
,
dtype
=
torch
.
float
,
...
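Dropping `.view(-1, 1)` is a shape change, not a cosmetic one: the soon-to-be-removed `_top_p_top_k` helper needed column vectors so the thresholds would broadcast against a `(batch, vocab)` probability matrix, while the flashinfer kernel takes one flat value per batch row. A tiny illustration of the two layouts:

```python
import torch

vals = [0.9, 0.5, 0.8]
col = torch.tensor(vals).view(-1, 1)  # (3, 1): broadcasts across the vocab dim
flat = torch.tensor(vals)             # (3,):  one threshold per batch row
print(col.shape, flat.shape)          # torch.Size([3, 1]) torch.Size([3])
```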
@@ -659,20 +660,17 @@ class Batch:
...
@@ -659,20 +660,17 @@ class Batch:
# TODO(lmzheng): apply penalty
# TODO(lmzheng): apply penalty
probs
=
torch
.
softmax
(
logits
,
dim
=-
1
)
probs
=
torch
.
softmax
(
logits
,
dim
=-
1
)
probs_sort
,
probs_idx
=
_top_p_top_k
(
probs
,
self
.
top_ps
,
self
.
top_ks
)
try
:
try
:
sampled_index
=
torch
.
multinomial
(
probs_sort
,
num_samples
=
1
)
max_top_k_round
,
batch_size
=
32
,
probs
.
shape
[
0
]
except
RuntimeError
as
e
:
uniform_samples
=
torch
.
rand
(
warnings
.
warn
(
f
"Ignore errors in sampling:
{
e
}
"
)
(
max_top_k_round
,
batch_size
),
device
=
probs
.
device
sampled_index
=
torch
.
ones
(
probs_sort
.
shape
[:
-
1
]
+
(
1
,),
dtype
=
torch
.
int64
,
device
=
probs
.
device
)
)
batch_next_token_ids
=
to
rch
.
gather
(
probs_idx
,
dim
=
1
,
index
=
sampled_index
).
view
(
batch_next_token_ids
,
_
=
to
p_k_top_p_sampling_from_probs
(
-
1
probs
,
uniform_samples
,
self
.
top_ks
,
self
.
top_ps
)
)
batch_next_token_probs
=
torch
.
gather
(
except
RuntimeError
as
e
:
probs_sort
,
dim
=
1
,
index
=
sampled_index
warnings
.
warn
(
f
"Ignore errors in sampling:
{
e
}
"
)
).
view
(
-
1
)
batch_next_token_ids
=
torch
.
argmax
(
probs
,
dim
=
-
1
)
if
has_regex
:
if
has_regex
:
batch_next_token_ids_cpu
=
batch_next_token_ids
.
cpu
().
numpy
()
batch_next_token_ids_cpu
=
batch_next_token_ids
.
cpu
().
numpy
()
...
@@ -682,18 +680,7 @@ class Batch:
...
@@ -682,18 +680,7 @@ class Batch:
req
.
regex_fsm_state
,
batch_next_token_ids_cpu
[
i
]
req
.
regex_fsm_state
,
batch_next_token_ids_cpu
[
i
]
)
)
return
batch_next_token_ids
,
batch_next_token_probs
return
batch_next_token_ids
def
_top_p_top_k
(
probs
:
torch
.
Tensor
,
top_ps
:
torch
.
Tensor
,
top_ks
:
torch
.
Tensor
):
probs_sort
,
probs_idx
=
probs
.
sort
(
dim
=-
1
,
descending
=
True
)
probs_sum
=
torch
.
cumsum
(
probs_sort
,
dim
=-
1
)
probs_sort
[(
probs_sum
-
probs_sort
)
>
top_ps
]
=
0.0
probs_sort
[
torch
.
arange
(
0
,
probs
.
shape
[
-
1
],
device
=
probs
.
device
).
view
(
1
,
-
1
)
>=
top_ks
]
=
0.0
probs_sort
.
div_
(
probs_sort
.
max
(
dim
=-
1
,
keepdim
=
True
)[
0
])
return
probs_sort
,
probs_idx
@
dataclass
@
dataclass
...
...
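Taken together, the two hunks above replace the handwritten sort/mask/renormalize path (`_top_p_top_k`, `torch.multinomial`, and two `torch.gather` calls) with one fused kernel call, stop returning the sampled probabilities, and change the error fallback from a hard-coded index of 1 to greedy `argmax`. A hedged sketch of the new tail of `Batch.sample` in isolation (the `max_top_k_round = 32` constant comes straight from the diff; the function wrapper is illustrative):

```python
import warnings

import torch
from flashinfer.sampling import top_k_top_p_sampling_from_probs


def sample_next_tokens(probs, top_ks, top_ps):
    """Fused top-k/top-p sampling with a greedy fallback, as in the new code."""
    try:
        max_top_k_round, batch_size = 32, probs.shape[0]
        uniform_samples = torch.rand(
            (max_top_k_round, batch_size), device=probs.device
        )
        next_token_ids, _ = top_k_top_p_sampling_from_probs(
            probs, uniform_samples, top_ks, top_ps
        )
    except RuntimeError as e:
        warnings.warn(f"Ignore errors in sampling: {e}")
        # Fallback is now greedy over the raw distribution rather than the
        # old sentinel index of 1.
        next_token_ids = torch.argmax(probs, dim=-1)
    return next_token_ids
```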
python/sglang/srt/managers/controller/tp_worker.py

```diff
@@ -451,7 +451,7 @@ class ModelTpServer:
             # Forward and sample the next tokens
             if batch.extend_num_tokens != 0:
                 output = self.model_runner.forward(batch, ForwardMode.EXTEND)
-                next_token_ids, _ = batch.sample(output.next_token_logits)
+                next_token_ids = batch.sample(output.next_token_logits)

                 # Move logprobs to cpu
                 if output.next_token_logprobs is not None:
@@ -574,7 +574,7 @@ class ModelTpServer:
         # Forward and sample the next tokens
         output = self.model_runner.forward(batch, ForwardMode.DECODE)
-        next_token_ids, _ = batch.sample(output.next_token_logits)
+        next_token_ids = batch.sample(output.next_token_logits)

         # Move logprobs to cpu
         if output.next_token_logprobs is not None:
```
python/sglang/srt/server.py

```diff
@@ -154,7 +154,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     if not server_args.disable_flashinfer:
         assert_pkg_version(
             "flashinfer",
-            "0.0.8",
+            "0.1.0",
             "Please uninstall the old version and "
             "reinstall the latest version by following the instructions "
             "at https://docs.flashinfer.ai/installation.html.",
```
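The minimum flashinfer version rises from 0.0.8 to 0.1.0, presumably so the sampling kernel imported in infer_batch.py behaves as the new code expects. `assert_pkg_version` is sglang's own helper; as an illustrative approximation (not sglang's actual implementation), a version gate of this kind can be built on the standard library plus `packaging`:

```python
from importlib.metadata import PackageNotFoundError, version

from packaging.version import parse  # third-party, but ubiquitous


def assert_pkg_version_sketch(pkg: str, min_version: str, message: str) -> None:
    """Fail fast when a required package is missing or too old."""
    try:
        installed = version(pkg)
    except PackageNotFoundError:
        raise RuntimeError(f"{pkg} is not installed. {message}")
    if parse(installed) < parse(min_version):
        raise RuntimeError(
            f"{pkg}=={installed} is too old (need >= {min_version}). {message}"
        )


assert_pkg_version_sketch(
    "flashinfer",
    "0.1.0",
    "Please uninstall the old version and reinstall the latest version by "
    "following the instructions at https://docs.flashinfer.ai/installation.html.",
)
```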