sglang · Commit a837166e (unverified)

Fix select and normalized logprobs (#67)

Authored Jan 21, 2024 by Lianmin Zheng; committed by GitHub on Jan 21, 2024.
Parent: 11f3cca6

Showing 3 changed files with 52 additions and 101 deletions (+52, -101).
python/sglang/srt/layers/get_selected_logprob.py    +0  -79
python/sglang/srt/layers/logits_processor.py        +45 -15
python/sglang/srt/managers/router/model_runner.py   +7  -7
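In short: the Triton segmented-gather kernel in `get_selected_logprob.py` is deleted, `compute_normalized_logprobs` in `logits_processor.py` is rewritten to do the gather and normalization with plain vectorized torch indexing and cumulative sums driven by `extend_seq_lens`/`extend_start_loc`, and `model_runner.py` picks up small fixes (an explicit `dim=0` in a `torch.cumsum` call, rank-prefixed weight-loading log messages, and a reformatted out-of-memory error).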
python/sglang/srt/layers/get_selected_logprob.py (deleted, 100644 → 0) · view file @ 11f3cca6
The entire file was removed:

```python
import torch
import triton
import triton.language as tl

from sglang.srt.utils import wrap_kernel_launcher


@triton.jit
def _fwd_segmented_gather(
    all_logits,
    len_add_1,
    cum_len,
    input_ids,
    logprobs,
    max_seq_len,
    voc_size: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    cur_req = tl.program_id(0)
    cur_l = tl.load(len_add_1 + cur_req)
    cum_l = tl.load(cum_len + cur_req)

    for i in range(0, (max_seq_len + BLOCK_SIZE - 1) // BLOCK_SIZE):
        off = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = off < cur_l - 1
        idx = tl.load(input_ids + cum_l - cur_l + off + 1, mask=mask)
        data = tl.load(all_logits + (cum_l - cur_l + off) * voc_size + idx, mask=mask)
        tl.store(logprobs + cum_l - cur_l - cur_req + off, data, mask=mask)


cached_kernel = None


def get_selected_logprob(all_logits, len_add_1, input_ids, logprobs):
    cum_len = torch.cumsum(len_add_1, dtype=torch.int32, dim=0)
    voc_size = all_logits.shape[1]
    grid = (len_add_1.shape[0], 1, 1)
    max_seq_len = len_add_1.max().item()

    global cached_kernel
    if cached_kernel:
        cached_kernel(
            grid,
            4,
            all_logits,
            len_add_1,
            cum_len,
            input_ids,
            logprobs,
            max_seq_len,
        )
        return

    _fwd_segmented_gather[grid](
        all_logits,
        len_add_1,
        cum_len,
        input_ids,
        logprobs,
        max_seq_len,
        voc_size,
        BLOCK_SIZE=128,
    )
    cached_kernel = wrap_kernel_launcher(_fwd_segmented_gather)


if __name__ == "__main__":
    all_logits = torch.tensor(
        # s s s
        [[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6], [4, 5, 6, 7]],
        dtype=torch.float32,
        device="cuda",
    )
    len_add_1 = torch.tensor([2, 3], dtype=torch.int32, device="cuda")
    input_ids = torch.tensor([1, 2, 3, 0, 1], dtype=torch.int32, device="cuda")
    logprobs = torch.empty((3), dtype=torch.float32, device="cuda")

    get_selected_logprob(all_logits, len_add_1, input_ids, logprobs)
    print(logprobs)
    # assert logprobs == [2, 2, 4]
```

(The test block originally called an undefined `get_selected_logprobs`; the call above uses the function actually defined in the file.)
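For readers skimming the deleted kernel: what it computes is a segmented gather. For each request occupying a contiguous block of rows in `all_logits`, it reads `all_logits[t, input_ids[t + 1]]` for every row `t` of that request except its last one, and packs the results contiguously into `logprobs`. Below is a minimal pure-PyTorch sketch of the same operation; this is an illustration written for this page, not code from the repository, and it runs on CPU:

```python
import torch

# Hypothetical reference for _fwd_segmented_gather (illustration only).
def segmented_gather_ref(all_logits, len_add_1, input_ids):
    out = []
    pos = 0
    for l in len_add_1.tolist():
        # Request rows are [pos, pos + l); each row except the last scores
        # the *next* token in input_ids.
        for t in range(l - 1):
            out.append(all_logits[pos + t, input_ids[pos + t + 1]])
        pos += l
    return torch.stack(out)

all_logits = torch.tensor(
    [[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6], [4, 5, 6, 7]],
    dtype=torch.float32,
)
len_add_1 = torch.tensor([2, 3])
input_ids = torch.tensor([1, 2, 3, 0, 1])
print(segmented_gather_ref(all_logits, len_add_1, input_ids))  # tensor([2., 2., 4.])
```

The printed values match the `# assert logprobs == [2, 2, 4]` comment in the deleted test block.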
python/sglang/srt/layers/logits_processor.py · view file @ a837166e
```diff
 import torch
-from sglang.srt.layers.get_selected_logprob import get_selected_logprob
 from sglang.srt.managers.router.model_runner import ForwardMode, InputMetadata
 from torch import nn
 from vllm.model_executor.parallel_utils.communication_op import (
 ...
@@ -54,25 +53,56 @@ class LogitsProcessor(nn.Module):
             normalized_logprobs = compute_normalized_logprobs(
                 all_logprobs,
-                input_metadata.seq_lens - input_metadata.prefix_lens,
                 input_ids,
+                input_metadata.extend_seq_lens,
+                input_metadata.extend_start_loc,
             )
             last_logits = logits[last_index]
             return last_logits, normalized_logprobs
 
 
-def compute_normalized_logprobs(all_logprobs, len_add_1, input_ids):
-    # assert all_logprobs.shape[0] == input_ids.shape[0] == torch.sum(len_add_1)
-    logprobs = torch.zeros(
-        (all_logprobs.shape[0] - len_add_1.shape[0]),
-        dtype=torch.float32,
-        device="cuda",
-    )
-    get_selected_logprob(all_logprobs, len_add_1, input_ids, logprobs)
-    cumsum = torch.cumsum(logprobs, dim=0, dtype=torch.float32)
-    end = torch.cumsum(len_add_1.sub_(1), dim=0)
-    start = torch.cat((torch.tensor([0], device="cuda"), end[:-1]), 0)
-    end.sub_(1)
-    torch.cuda.synchronize()
-    sum_logp = cumsum[end] - cumsum[start] + logprobs[start]
-    res = sum_logp / len_add_1
-    return res
+def compute_normalized_logprobs(all_logprobs, input_ids, seq_lens, start_loc):
+    logprobs = all_logprobs[
+        torch.arange(all_logprobs.shape[0], device="cuda"),
+        torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
+    ]
+    logprobs_cumsum = torch.cumsum(logprobs, dim=0, dtype=torch.float32)
+
+    start = start_loc.clone()
+    end = start + seq_lens - 2
+    start.clamp_(min=0, max=logprobs.shape[0] - 1)
+    end.clamp_(min=0, max=logprobs.shape[0] - 1)
+    sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + logprobs[start]
+    return sum_logp / ((seq_lens - 1).clamp(min=1))
+
+
+if __name__ == "__main__":
+    all_logprobs = torch.tensor(
+        # s s s
+        [[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6], [4, 5, 6, 7]],
+        dtype=torch.float32,
+        device="cuda",
+    )
+    seq_lens = torch.tensor([2, 0, 3, 0], dtype=torch.int32, device="cuda")
+    input_ids = torch.tensor([1, 2, 3, 0, 1], dtype=torch.int32, device="cuda")
+    logprobs = torch.zeros(5, dtype=torch.float32, device="cuda")
+
+    logprobs = all_logprobs[
+        torch.arange(all_logprobs.shape[0], device="cuda"),
+        torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
+    ]
+    logprobs_cumsum = torch.cumsum(logprobs, dim=0, dtype=torch.float32)
+
+    len_cumsum = torch.cumsum(seq_lens, dim=0)
+    start = torch.cat((torch.tensor([0], device="cuda"), len_cumsum[:-1]), 0)
+    end = start + seq_lens - 2
+    start.clamp_(min=0, max=logprobs.shape[0] - 1)
+    end.clamp_(min=0, max=logprobs.shape[0] - 1)
+    sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + logprobs[start]
+
+    # assert logprobs == [2, _, 2, 4, _]
+    print("logprobs", logprobs)
+    print("start", start)
+    print("end", end)
+    print("sum_logp", sum_logp)
```
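To make the vectorized arithmetic above concrete, here is a small loop-based check; this is hypothetical code written for this page, not part of the commit. It recomputes the per-request normalized logprobs from the same test tensors on CPU, where `seq_lens = [2, 3]` corresponds to the two nonzero entries of the padded `[2, 0, 3, 0]` used in the `__main__` block:

```python
import torch

all_logprobs = torch.tensor(
    [[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6], [4, 5, 6, 7]],
    dtype=torch.float32,
)
seq_lens = [2, 3]  # per-request lengths (nonzero entries of [2, 0, 3, 0])
input_ids = [1, 2, 3, 0, 1]

pos = 0
for sl in seq_lens:
    # Row t of a request scores the *next* token, so a request of length sl
    # contributes sl - 1 selected logprobs, which are averaged.
    selected = [all_logprobs[pos + t, input_ids[pos + t + 1]] for t in range(sl - 1)]
    print(float(sum(selected)) / (sl - 1))  # 2.0 for request 0, 3.0 for request 1
    pos += sl
```

Request 0 contributes the single logprob 2 (normalized 2.0); request 1 contributes 2 + 4 = 6 over two tokens (normalized 3.0), consistent with `sum_logp` and the `(seq_lens - 1)` divisor in the new function.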
python/sglang/srt/managers/router/model_runner.py · view file @ a837166e
```diff
+import logging
 from dataclasses import dataclass
 from enum import Enum, auto
 from typing import List
 
-import logging
 import numpy as np
 import torch
 ...
@@ -13,7 +13,6 @@ from vllm.model_executor.layers.quantization.awq import AWQConfig
 from vllm.model_executor.model_loader import _set_default_torch_dtype
 from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel
 
 logger = logging.getLogger("model_runner")
```
```diff
 ...
@@ -112,7 +111,7 @@ class InputMetadata:
     def init_extend_args(self):
         self.extend_seq_lens = self.seq_lens - self.prefix_lens
         self.extend_start_loc = torch.zeros_like(self.seq_lens)
-        self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], 0)
+        self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
         self.max_extend_len = int(torch.max(self.extend_seq_lens))
 
     @classmethod
```
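A side note on the hunk above: `extend_start_loc` is an exclusive prefix sum of `extend_seq_lens`, i.e. the offset at which each request's extended tokens begin in the packed layout; the commit only makes the `dim` argument explicit. A tiny standalone illustration with assumed example lengths, runnable on CPU:

```python
import torch

extend_seq_lens = torch.tensor([2, 3, 1])  # example lengths, not from the commit
extend_start_loc = torch.zeros_like(extend_seq_lens)
# Exclusive cumulative sum: request i starts where requests 0..i-1 end.
extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
print(extend_start_loc)  # tensor([0, 2, 5])
```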
```diff
 ...
@@ -262,7 +261,7 @@ class ModelRunner:
         if model_class is None:
             raise ValueError(f"Unsupported architectures: {architectures}")
 
-        logger.info("load weight begin.")
+        logger.info(f"Rank {self.tp_rank}: load weight begin.")
 
         # Load weights
         linear_method = None
```
```diff
 ...
@@ -287,7 +286,7 @@ class ModelRunner:
         )
         self.model = model.eval()
 
-        logger.info("load weight end.")
+        logger.info(f"Rank {self.tp_rank}: load weight end.")
 
     def profile_max_num_token(self, total_gpu_memory):
         available_gpu_memory = get_available_gpu_memory(
 ...
```
```diff
@@ -308,8 +307,9 @@ class ModelRunner:
         self.max_total_num_token = self.profile_max_num_token(total_gpu_memory)
         if self.max_total_num_token <= 0:
-            raise RuntimeError("Not enought memory. "
-                               "Please try to increase --mem-fraction-static.")
+            raise RuntimeError(
+                "Not enought memory. " "Please try to increase --mem-fraction-static."
+            )
         self.req_to_token_pool = ReqToTokenPool(
             int(self.max_total_num_token / self.model_config.context_len * 256),
 ...
```