change / sglang · Commits

Commit f25b76c0 (Unverified)
authored Jul 08, 2024 by Liangsheng Yin, committed by GitHub on Jul 08, 2024

add `LogitsMetadata` (#604)

parent f4e885b7

Showing 7 changed files with 66 additions and 40 deletions (+66 -40)
benchmark/line_retrieval/gen_data.py             +3  -3
python/sglang/srt/layers/logits_processor.py     +50 -19
python/sglang/srt/layers/radix_attention.py      +1  -2
python/sglang/srt/managers/tokenizer_manager.py  +8  -8
python/sglang/srt/models/gemma2.py               +0  -5
python/sglang/srt/models/llama2.py               +3  -3
python/sglang/srt/utils.py                       +1  -0
benchmark/line_retrieval/gen_data.py  (+3 -3)

```diff
@@ -48,9 +48,9 @@ def generate_lines(random_words, num_lines, redirect_ratio):
     )
     for i in redirect_indices:
         target_idx = np.random.choice(min(i * 2 + 100, num_lines))
-        lines[i] = (
-            f"Line {indices[i]}: The REGISTER_CONTENT is the same as Line {indices[target_idx]}."
-        )
+        lines[
+            i
+        ] = f"Line {indices[i]}: The REGISTER_CONTENT is the same as Line {indices[target_idx]}."
         redirects[i] = target_idx

     # Build links and find sources
```
python/sglang/srt/layers/logits_processor.py  (+50 -19)

```diff
 """Logits processing."""

 import dataclasses
-from typing import List
+from typing import List, Union

 import torch
 from torch import nn
```
```diff
@@ -31,6 +31,27 @@ class LogitProcessorOutput:
     decode_top_logprobs: List


+@dataclasses.dataclass
+class LogitsMetadata:
+    forward_mode: ForwardMode
+    extend_seq_lens: torch.Tensor
+    extend_start_loc: torch.Tensor
+
+    # For logprobs
+    return_logprob: bool
+    top_logprobs_nums: List[int]
+
+    @classmethod
+    def from_input_metadata(cls, input_metadata: InputMetadata):
+        return cls(
+            forward_mode=input_metadata.forward_mode,
+            extend_seq_lens=input_metadata.extend_seq_lens,
+            extend_start_loc=input_metadata.extend_start_loc,
+            return_logprob=input_metadata.return_logprob,
+            top_logprobs_nums=input_metadata.top_logprobs_nums,
+        )
+
+
 class LogitsProcessor(nn.Module):
     def __init__(self, config):
         super().__init__()
```
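The new dataclass copies only the fields the logits path actually reads, so `LogitsProcessor` no longer needs the full `InputMetadata`. As a minimal illustration of the pattern (toy code, not from the commit; `SlimMetadata` and the `SimpleNamespace` stand-in are hypothetical):

```python
import dataclasses
from types import SimpleNamespace
from typing import List


@dataclasses.dataclass
class SlimMetadata:  # stands in for LogitsMetadata
    return_logprob: bool
    top_logprobs_nums: List[int]

    @classmethod
    def from_input_metadata(cls, input_metadata):
        # Copy only the fields the logits path reads, dropping the rest.
        return cls(
            return_logprob=input_metadata.return_logprob,
            top_logprobs_nums=input_metadata.top_logprobs_nums,
        )


# A fake "large" metadata object with an unrelated field that is not copied.
fake_input_metadata = SimpleNamespace(
    return_logprob=True,
    top_logprobs_nums=[5],
    kv_cache=None,  # unrelated field, ignored by the slim view
)
slim = SlimMetadata.from_input_metadata(fake_input_metadata)
assert slim.top_logprobs_nums == [5]
```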
```diff
@@ -38,14 +59,14 @@ class LogitsProcessor(nn.Module):
         self.tp_size = get_tensor_model_parallel_world_size()

     def _get_normalized_prompt_logprobs(
-        self, prefill_token_logprobs, input_metadata: InputMetadata
+        self, prefill_token_logprobs, logits_metadata: LogitsMetadata
     ):
         logprobs_cumsum = torch.cumsum(
             prefill_token_logprobs, dim=0, dtype=torch.float32
         )

-        start = input_metadata.extend_start_loc.clone()
-        end = start + input_metadata.extend_seq_lens - 2
+        start = logits_metadata.extend_start_loc.clone()
+        end = start + logits_metadata.extend_seq_lens - 2
         start.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1)
         end.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1)
         sum_logp = (
```
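For readers following the arithmetic: a self-contained sketch (toy tensors with assumed shapes, not sglang code) of the range-sum trick used above. The sum of logprobs over positions `start..end` inclusive is `cumsum[end] - cumsum[start] + logprobs[start]`, and dividing by the clamped `extend_seq_lens - 1` gives the per-sequence normalized prompt logprob.

```python
import torch

# Toy values standing in for the real metadata: two packed sequences,
# starting at offsets 0 and 3, each contributing 3 logprob entries.
prefill_token_logprobs = torch.tensor([-0.5, -1.0, -0.2, -0.3, -0.7, -0.1])
extend_start_loc = torch.tensor([0, 3])
extend_seq_lens = torch.tensor([3, 3])

logprobs_cumsum = torch.cumsum(prefill_token_logprobs, dim=0, dtype=torch.float32)
start = extend_start_loc.clone()
end = start + extend_seq_lens - 2
start.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1)
end.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1)

# Inclusive range sum via cumsum: cumsum[end] - cumsum[start] + value[start].
sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + prefill_token_logprobs[start]
normalized = sum_logp / (extend_seq_lens - 1).clamp(min=1)
print(normalized)  # mean logprob per prompt token, one entry per sequence
```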
```diff
@@ -54,17 +75,17 @@ class LogitsProcessor(nn.Module):
             + prefill_token_logprobs[start]
         )
         normalized_prompt_logprobs = sum_logp / (
-            (input_metadata.extend_seq_lens - 1).clamp(min=1)
+            (logits_metadata.extend_seq_lens - 1).clamp(min=1)
         )
         return normalized_prompt_logprobs

-    def _get_top_logprobs(self, all_logprobs, input_metadata: InputMetadata):
+    def _get_top_logprobs(self, all_logprobs, logits_metadata: LogitsMetadata):
         # TODO: vectorize the code below
-        if input_metadata.forward_mode == ForwardMode.DECODE:
+        if logits_metadata.forward_mode == ForwardMode.DECODE:
             decode_top_logprobs = []
             for i in range(all_logprobs.shape[0]):
-                k = input_metadata.top_logprobs_nums[i]
+                k = logits_metadata.top_logprobs_nums[i]
                 t = all_logprobs[i].topk(k)
                 v_cpu = t.values.tolist()
                 p_cpu = t.indices.tolist()
```
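The decode branch reduces to a per-row `topk` with a per-request `k`. A toy sketch of the same loop (illustrative only, not sglang code):

```python
import torch

# Toy decode-mode batch: 2 rows of vocab-size-5 logprobs, per-row k values.
all_logprobs = torch.log_softmax(torch.randn(2, 5), dim=-1)
top_logprobs_nums = [2, 3]

decode_top_logprobs = []
for i in range(all_logprobs.shape[0]):
    k = top_logprobs_nums[i]
    t = all_logprobs[i].topk(k)
    v_cpu = t.values.tolist()   # top-k logprob values
    p_cpu = t.indices.tolist()  # corresponding token ids
    decode_top_logprobs.append(list(zip(v_cpu, p_cpu)))
```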
```diff
@@ -73,13 +94,13 @@ class LogitsProcessor(nn.Module):
         else:
             prefill_top_logprobs, decode_top_logprobs = [], []
             pt = 0
-            extend_seq_lens_cpu = input_metadata.extend_seq_lens.tolist()
+            extend_seq_lens_cpu = logits_metadata.extend_seq_lens.tolist()
             for i, extend_seq_len in enumerate(extend_seq_lens_cpu):
                 if extend_seq_len == 0:
                     prefill_top_logprobs.append([])
                     decode_top_logprobs.append([])
                     continue
-                k = input_metadata.top_logprobs_nums[i]
+                k = logits_metadata.top_logprobs_nums[i]
                 t = all_logprobs[pt : pt + extend_seq_len].topk(k)
                 vs_cpu = t.values.tolist()
                 ps_cpu = t.indices.tolist()
```
```diff
@@ -91,14 +112,24 @@ class LogitsProcessor(nn.Module):

         return prefill_top_logprobs, decode_top_logprobs

-    def forward(self, input_ids, hidden_states, weight, input_metadata: InputMetadata):
+    def forward(
+        self,
+        input_ids,
+        hidden_states,
+        weight,
+        logits_metadata: Union[LogitsMetadata, InputMetadata],
+    ):
+        if isinstance(logits_metadata, InputMetadata):
+            logits_metadata = LogitsMetadata.from_input_metadata(logits_metadata)
+        assert isinstance(logits_metadata, LogitsMetadata)
+
         # Get the last hidden states and last logits for the next token prediction
-        if input_metadata.forward_mode == ForwardMode.DECODE:
+        if logits_metadata.forward_mode == ForwardMode.DECODE:
             last_index = None
             last_hidden = hidden_states
         else:
             last_index = (
-                torch.cumsum(input_metadata.extend_seq_lens, dim=0, dtype=torch.long)
+                torch.cumsum(logits_metadata.extend_seq_lens, dim=0, dtype=torch.long)
                 - 1
             )
             last_hidden = hidden_states[last_index]
```
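In extend mode, `hidden_states` is a packed (ragged) batch, so the last position of each sequence sits at the cumulative sequence length minus one. A minimal sketch with toy shapes (not sglang code):

```python
import torch

# Packed batch: two sequences of lengths 3 and 2, toy hidden size 4.
extend_seq_lens = torch.tensor([3, 2])
hidden_states = torch.arange(5 * 4, dtype=torch.float32).reshape(5, 4)

# The last token of each sequence sits at the cumulative length minus one.
last_index = torch.cumsum(extend_seq_lens, dim=0, dtype=torch.long) - 1
last_hidden = hidden_states[last_index]  # rows 2 and 4
print(last_index)  # tensor([2, 4])
```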
```diff
@@ -114,7 +145,7 @@ class LogitsProcessor(nn.Module):
             last_logits *= self.config.final_logit_softcapping

         # Return only last_logits if logprob is not requested
-        if not input_metadata.return_logprob:
+        if not logits_metadata.return_logprob:
             return LogitProcessorOutput(
                 next_token_logits=last_logits,
                 next_token_logprobs=None,
```
```diff
@@ -125,7 +156,7 @@ class LogitsProcessor(nn.Module):
             )
         else:
             # When logprob is requested, compute the logits for all tokens.
-            if input_metadata.forward_mode == ForwardMode.DECODE:
+            if logits_metadata.forward_mode == ForwardMode.DECODE:
                 all_logits = last_logits
             else:
                 all_logits = torch.matmul(hidden_states, weight.T)
```
```diff
@@ -138,15 +169,15 @@ class LogitsProcessor(nn.Module):
             all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)

             # Get the logprob of top-k tokens
-            return_top_logprob = any(x > 0 for x in input_metadata.top_logprobs_nums)
+            return_top_logprob = any(x > 0 for x in logits_metadata.top_logprobs_nums)
             if return_top_logprob:
                 prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs(
-                    all_logprobs, input_metadata
+                    all_logprobs, logits_metadata
                 )
             else:
                 prefill_top_logprobs = decode_top_logprobs = None

-            if input_metadata.forward_mode == ForwardMode.DECODE:
+            if logits_metadata.forward_mode == ForwardMode.DECODE:
                 return LogitProcessorOutput(
                     next_token_logits=last_logits,
                     next_token_logprobs=all_logprobs,
```
```diff
@@ -166,7 +197,7 @@ class LogitsProcessor(nn.Module):
             ]
             normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
-                prefill_token_logprobs, input_metadata
+                prefill_token_logprobs, logits_metadata
             )

             return LogitProcessorOutput(
```
python/sglang/srt/layers/radix_attention.py  (+1 -2)

```diff
@@ -2,9 +2,8 @@

 import numpy as np
 import torch
-from torch import nn
-
 from flashinfer.cascade import merge_state
+from torch import nn

 from sglang.global_config import global_config
 from sglang.srt.layers.extend_attention import extend_attention_fwd
```
python/sglang/srt/managers/tokenizer_manager.py  (+8 -8)

```diff
@@ -334,15 +334,15 @@ class TokenizerManager:
                 ret["meta_info"]["decode_token_logprobs"], return_text_in_logprobs
             )
             if top_logprobs_num > 0:
-                ret["meta_info"]["prefill_top_logprobs"] = (
-                    self.detokenize_top_logprobs_tokens(
-                        ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs
-                    )
+                ret["meta_info"][
+                    "prefill_top_logprobs"
+                ] = self.detokenize_top_logprobs_tokens(
+                    ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs
                 )
-                ret["meta_info"]["decode_top_logprobs"] = (
-                    self.detokenize_top_logprobs_tokens(
-                        ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
-                    )
+                ret["meta_info"][
+                    "decode_top_logprobs"
+                ] = self.detokenize_top_logprobs_tokens(
+                    ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
                 )

         return ret
```
python/sglang/srt/models/gemma2.py  (+0 -5)

```diff
@@ -81,7 +81,6 @@ from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding


 class GemmaRotaryEmbedding(RotaryEmbedding):
-
     def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
         # https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/gemma/modeling_gemma.py#L107
         inv_freq = 1.0 / (
@@ -95,7 +94,6 @@ class GemmaRotaryEmbedding(RotaryEmbedding):


 class Gemma2MLP(nn.Module):
-
     def __init__(
         self,
         hidden_size: int,
@@ -127,7 +125,6 @@ class Gemma2MLP(nn.Module):


 class Gemma2Attention(nn.Module):
-
     def __init__(
         self,
         layer_idx: int,
@@ -218,7 +215,6 @@ class Gemma2Attention(nn.Module):


 class Gemma2DecoderLayer(nn.Module):
-
     def __init__(
         self,
         layer_idx: int,
@@ -287,7 +283,6 @@ class Gemma2DecoderLayer(nn.Module):


 class Gemma2Model(nn.Module):
-
     def __init__(
         self,
         config: PretrainedConfig,
```
python/sglang/srt/models/llama2.py  (+3 -3)

```diff
@@ -163,9 +163,9 @@ class LlamaDecoderLayer(nn.Module):
         if rope_scaling is not None and getattr(
             config, "original_max_position_embeddings", None
         ):
-            rope_scaling["original_max_position_embeddings"] = (
-                config.original_max_position_embeddings
-            )
+            rope_scaling[
+                "original_max_position_embeddings"
+            ] = config.original_max_position_embeddings
         rope_is_neox_style = getattr(config, "rope_is_neox_style", True)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
         self.self_attn = LlamaAttention(
```
python/sglang/srt/utils.py  (+1 -0)

```diff
@@ -459,6 +459,7 @@ def monkey_patch_vllm_p2p_access_check(gpu_id: int):
     """
     import vllm.distributed.device_communicators.custom_all_reduce_utils as tgt

+    setattr(tgt, "gpu_p2p_access_check", lambda *arg, **kwargs: True)
```
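The patch simply rebinds a module attribute to a permissive stub. The same `setattr` pattern in isolation, with `types.SimpleNamespace` standing in for the vllm module (illustrative only, not sglang code):

```python
from types import SimpleNamespace

# Stand-in for vllm.distributed.device_communicators.custom_all_reduce_utils.
tgt = SimpleNamespace(gpu_p2p_access_check=lambda src, dst: False)

# Replace the (slow) check with one that always reports success.
setattr(tgt, "gpu_p2p_access_check", lambda *arg, **kwargs: True)
assert tgt.gpu_p2p_access_check(0, 1) is True
```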