Commit bbc07c41 (unverified), authored by Lianmin Zheng on Jul 27, 2024; committed by GitHub on Jul 27, 2024.
Move sampling logits to float32 (#773)
Parent: a036d419

Showing 2 changed files with 43 additions and 9 deletions.
python/sglang/srt/layers/logits_processor.py (+3, -3)
python/sglang/srt/managers/controller/infer_batch.py (+40, -6)
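Before the diffs, a toy illustration (mine, not part of the commit) of why the cast in the title matters: a softmax over a large vocabulary in half precision rounds the entire tail to exactly zero, which is exactly the kind of degenerate distribution that trips top-k/top-p sampling kernels.

```python
import torch

# Toy peaked distribution: one dominant logit over a flat 32k-token tail.
logits = torch.zeros(1, 32000)
logits[0, 0] = 30.0

p32 = torch.softmax(logits, dim=-1)   # float32: tail keeps ~9e-14 per token
p16 = p32.half()                      # float16: tail underflows to exactly 0

print((p32 > 0).sum().item())  # 32000 -- every token retains some mass
print((p16 > 0).sum().item())  # 1     -- only the argmax survives the cast
```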
python/sglang/srt/layers/logits_processor.py

```diff
@@ -136,7 +136,7 @@ class LogitsProcessor(nn.Module):
         last_logits = torch.matmul(last_hidden, weight.T)
         if self.tp_size > 1:
             last_logits = tensor_model_parallel_all_gather(last_logits)
-        last_logits = last_logits[:, : self.config.vocab_size]
+        last_logits = last_logits[:, : self.config.vocab_size].float()

         if hasattr(self.config, "final_logit_softcapping"):
             last_logits /= self.config.final_logit_softcapping
@@ -161,9 +161,9 @@ class LogitsProcessor(nn.Module):
             all_logits = torch.matmul(hidden_states, weight.T)
             if self.tp_size > 1:
                 all_logits = tensor_model_parallel_all_gather(all_logits)
-            all_logits = all_logits[:, : self.config.vocab_size]
+            all_logits = all_logits[:, : self.config.vocab_size].float()

-            all_logprobs = all_logits.float()
+            all_logprobs = all_logits
             del all_logits
             all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
```
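One detail in the second hunk: because the vocab slice now produces float32 directly, the old `all_logprobs = all_logits.float()` copy becomes a plain alias, and the `del` plus slice assignment write the log-probs back into the same storage. A standalone sketch of that pattern (my reconstruction, not code from the commit):

```python
import torch

def logits_to_logprobs(logits: torch.Tensor) -> torch.Tensor:
    # Alias instead of the old `.float()` copy: the cast already happened
    # when the vocab slice was taken, so no second cast is needed here.
    logprobs = logits
    del logits
    # Slice assignment writes the result back into the existing storage.
    logprobs[:] = torch.nn.functional.log_softmax(logprobs, dim=-1)
    return logprobs

x = torch.randn(4, 32000, dtype=torch.float32)
out = logits_to_logprobs(x)
print(out.exp().sum(dim=-1))  # each row sums to ~1.0
```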
python/sglang/srt/managers/controller/infer_batch.py

```diff
@@ -687,13 +687,21 @@ class Batch:
         # TODO(lmzheng): apply penalty
         probs = torch.softmax(logits, dim=-1)

+        if True:
             max_top_k_round, batch_size = 32, probs.shape[0]
-            uniform_samples = torch.rand((max_top_k_round, batch_size), device=probs.device)
+            uniform_samples = torch.rand(
+                (max_top_k_round, batch_size), device=probs.device
+            )
             batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
                 probs, uniform_samples, self.top_ks, self.top_ps
             )
+        else:
+            # Here we provide a slower fallback implementation.
+            batch_next_token_ids, success = top_k_top_p_sampling_from_probs_torch(
+                probs, self.top_ks, self.top_ps
+            )

-        if torch.any(~success):
+        if not torch.all(success):
             warnings.warn("Sampling failed, fallback to top_k=1 strategy")
             probs = probs.masked_fill(torch.isnan(probs), 0.0)
             argmax_ids = torch.argmax(probs, dim=-1)
```
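The hunk ends just as the greedy fallback begins; presumably the rows whose kernel call failed are then patched with the argmax ids. A hedged sketch of that continuation (the `torch.where` line is my assumption, not shown in the diff):

```python
import warnings
import torch

def patch_failed_rows(batch_next_token_ids: torch.Tensor,
                      success: torch.Tensor,
                      probs: torch.Tensor) -> torch.Tensor:
    """Replace token ids of rows whose sampling kernel reported failure."""
    if not torch.all(success):
        warnings.warn("Sampling failed, fallback to top_k=1 strategy")
        # NaN probabilities are what typically trips the kernel; zero them first.
        probs = probs.masked_fill(torch.isnan(probs), 0.0)
        argmax_ids = torch.argmax(probs, dim=-1)
        # Assumed continuation: keep sampled ids where the kernel succeeded,
        # fall back to greedy (top_k=1) ids elsewhere.
        batch_next_token_ids = torch.where(success, batch_next_token_ids, argmax_ids)
    return batch_next_token_ids
```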
```diff
@@ -933,3 +941,29 @@ def init_triton_args(forward_mode, seq_lens, prefix_lens):
         max_extend_len = int(torch.max(extend_seq_lens))

     return max_seq_len, max_extend_len, start_loc, prefix_lens
+
+
+def top_k_top_p_sampling_from_probs_torch(
+    probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor
+):
+    """A top-k and top-p sampling implementation with native pytorch operations."""
+    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
+    probs_sum = torch.cumsum(probs_sort, dim=-1)
+    probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
+    probs_sort[
+        torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1)
+        >= top_ks.view(-1, 1)
+    ] = 0.0
+    probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
+    try:
+        sampled_index = torch.multinomial(probs_sort, num_samples=1)
+    except RuntimeError:
+        batch_next_token_ids = torch.zeros(
+            (probs_sort.shape[0],), dtype=torch.int64, device=probs.device
+        )
+        success = torch.zeros(probs.shape[0], dtype=torch.bool, device=probs.device)
+        return batch_next_token_ids, success
+    batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
+    success = torch.ones(probs.shape[0], dtype=torch.bool, device=probs.device)
+    return batch_next_token_ids, success
```
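A quick sanity check of the new helper with toy inputs of my own (assuming the function above is importable from the module it is added to in this commit):

```python
import torch
from sglang.srt.managers.controller.infer_batch import (
    top_k_top_p_sampling_from_probs_torch,  # the helper added above
)

# Two rows over a 6-token vocab; per-row top-k and top-p settings.
probs = torch.softmax(
    torch.tensor([[3.0, 2.0, 1.0, 0.5, 0.1, 0.0],
                  [0.0, 0.1, 0.5, 1.0, 2.0, 3.0]]),
    dim=-1,
)
top_ks = torch.tensor([2, 3])      # row 0 keeps 2 tokens, row 1 keeps 3
top_ps = torch.tensor([0.9, 0.9])

ids, success = top_k_top_p_sampling_from_probs_torch(probs, top_ks, top_ps)
print(ids, success)  # row 0 only ever draws token 0 or 1; row 1 draws 5, 4, or 3
```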