Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
c71880f8
"server/vscode:/vscode.git/clone" did not exist on "404ed7a1f6cfef02ac8ee71c934f87b056d3a06c"
Unverified
Commit
c71880f8
authored
Jul 28, 2024
by
Ying Sheng
Committed by
GitHub
Jul 28, 2024
Browse files
Vectorize logprobs computation (#787)
parent
bcb6611a
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
36 additions
and
17 deletions
+36
-17
python/sglang/srt/layers/logits_processor.py
python/sglang/srt/layers/logits_processor.py
+25
-12
python/sglang/srt/managers/tokenizer_manager.py
python/sglang/srt/managers/tokenizer_manager.py
+11
-5
No files found.
python/sglang/srt/layers/logits_processor.py
View file @
c71880f8
...
...
@@ -77,33 +77,46 @@ class LogitsProcessor(nn.Module):
@staticmethod
def get_top_logprobs(all_logprobs, logits_metadata: LogitsMetadata):
    """Extract per-request top-k ``(logprob, token_id)`` pairs.

    Args:
        all_logprobs: 2-D tensor of log-probabilities with one row per
            token position and one column per vocabulary entry.
        logits_metadata: carries ``forward_mode``, the per-request list
            ``top_logprobs_nums`` (requested k values) and, in extend mode,
            ``extend_seq_lens`` (how many positions each request owns).

    Returns:
        Tuple ``(input_top_logprobs, output_top_logprobs)``.  In decode
        mode the first element is ``None`` (there are no prompt positions);
        each output entry is a list of ``(logprob, token_id)`` tuples.
    """
    # Vectorized: one batched topk with the largest requested k, then
    # per-request slicing on the CPU lists.  This replaces the previous
    # per-row ``all_logprobs[i].topk(k)`` loop (one GPU kernel launch and
    # one device->host copy instead of one per request).
    max_k = max(logits_metadata.top_logprobs_nums)
    ret = all_logprobs.topk(max_k, dim=1)
    values = ret.values.tolist()
    indices = ret.indices.tolist()

    if logits_metadata.forward_mode == ForwardMode.DECODE:
        output_top_logprobs = []
        for i, k in enumerate(logits_metadata.top_logprobs_nums):
            # Each request sees only its own requested k, not max_k.
            output_top_logprobs.append(list(zip(values[i][:k], indices[i][:k])))
        return None, output_top_logprobs
    else:
        input_top_logprobs, output_top_logprobs = [], []
        pt = 0  # running offset of the current request's first row
        extend_seq_lens_cpu = logits_metadata.extend_seq_lens.tolist()
        for i, extend_seq_len in enumerate(extend_seq_lens_cpu):
            if extend_seq_len == 0:
                # Request contributed no positions; keep list alignment.
                input_top_logprobs.append([])
                output_top_logprobs.append([])
                continue
            k = logits_metadata.top_logprobs_nums[i]
            # All but the last position of the request's chunk are prompt
            # ("input") positions; the last row is the generated token.
            input_top_logprobs.append(
                [
                    list(zip(values[pt + j][:k], indices[pt + j][:k]))
                    for j in range(extend_seq_len - 1)
                ]
            )
            output_top_logprobs.append(
                list(
                    zip(
                        values[pt + extend_seq_len - 1][:k],
                        indices[pt + extend_seq_len - 1][:k],
                    )
                )
            )
            pt += extend_seq_len
        return input_top_logprobs, output_top_logprobs
...
...
python/sglang/srt/managers/tokenizer_manager.py
View file @
c71880f8
...
...
@@ -6,7 +6,7 @@ import dataclasses
import
logging
import
multiprocessing
as
mp
import
os
from
typing
import
Dict
,
List
from
typing
import
Dict
,
List
,
Tuple
import
numpy
as
np
import
transformers
...
...
@@ -469,7 +469,9 @@ class TokenizerManager:
)
return
ret
def
detokenize_logprob_tokens
(
self
,
token_logprobs
,
decode_to_text
:
bool
):
def
detokenize_logprob_tokens
(
self
,
token_logprobs
:
List
[
Tuple
[
float
,
int
]],
decode_to_text
:
bool
):
if
not
decode_to_text
:
return
[(
logprob
,
token_id
,
None
)
for
logprob
,
token_id
in
token_logprobs
]
...
...
@@ -481,9 +483,13 @@ class TokenizerManager:
]
def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text: bool):
    """Detokenize the top-k ``(logprob, token_id)`` list of every position.

    Args:
        top_logprobs: one entry per token position; each entry is either a
            falsy placeholder (left untouched) or a list of
            ``(logprob, token_id)`` tuples.
        decode_to_text: forwarded to ``detokenize_logprob_tokens``; when
            False, token ids are passed through with ``None`` text.

    Returns:
        The same ``top_logprobs`` list, mutated in place, with every
        non-empty entry replaced by its detokenized form.
    """
    # TODO: The current implementation only batches the detokenization for
    # top-k tokens per single position.
    # We should batch all top-k tokens in all positions.
    for i, token_top_logprobs in enumerate(top_logprobs):
        if token_top_logprobs:
            top_logprobs[i] = self.detokenize_logprob_tokens(
                token_top_logprobs, decode_to_text
            )
    return top_logprobs
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment