Commit 83e23c69 (unverified), authored Aug 21, 2024 by Liangsheng Yin, committed by GitHub on Aug 21, 2024

Improve code style of sampler (#1168)

Parent: ac1b74fa
Changes: 10 changed files with 269 additions and 195 deletions (+269, -195).
examples/usage/json_decode.py                        +3    -0
python/sglang/bench_latency.py                       +1    -1
python/sglang/srt/layers/sampler.py                  +101  -0
python/sglang/srt/managers/io_struct.py              +1    -1
python/sglang/srt/managers/schedule_batch.py         +10   -189
python/sglang/srt/managers/tokenizer_manager.py      +1    -1
python/sglang/srt/managers/tp_worker.py              +16   -0
python/sglang/srt/model_executor/model_runner.py     +0    -3
python/sglang/srt/sampling/sampling_batch_info.py    +136  -0
python/sglang/srt/sampling/sampling_params.py        +0    -0
examples/usage/json_decode.py

@@ -35,6 +35,9 @@ def character_gen(s, name):
        name
        + " is a character in Harry Potter. Please fill in the following information about this character.\n"
    )
    s += "The constrained regex is:\n"
    s += character_regex + "\n"
    s += "The JSON output is:\n"
    s += sgl.gen("json_output", max_tokens=256, regex=character_regex)

python/sglang/bench_latency.py

@@ -54,7 +54,7 @@ from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
from sglang.srt.model_config import ModelConfig
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.model_executor.model_runner import ModelRunner
-from sglang.srt.sampling_params import SamplingParams
+from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import suppress_other_loggers

python/sglang/srt/layers/sampler.py (new file, mode 0 → 100644)

import logging

import torch
from flashinfer.sampling import (
    min_p_sampling_from_probs,
    top_k_renorm_prob,
    top_k_top_p_sampling_from_probs,
    top_p_renorm_prob,
)
from vllm.model_executor.custom_op import CustomOp

# TODO: move this dict to another place
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo

logger = logging.getLogger(__name__)


class Sampler(CustomOp):
    def __init__(self):
        super().__init__()

    def forward_cuda(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
        # Post process logits
        logits = logits.contiguous()
        logits.div_(sampling_info.temperatures)
        if sampling_info.logit_bias is not None:
            logits.add_(sampling_info.logit_bias)
        if sampling_info.vocab_mask is not None:
            logits = logits.masked_fill(~sampling_info.vocab_mask, float("-inf"))
        logits = sampling_info.penalizer_orchestrator.apply(logits)

        probs = torch.softmax(logits, dim=-1)

        if not global_server_args_dict["disable_flashinfer_sampling"]:
            max_top_k_round, batch_size = 32, probs.shape[0]
            uniform_samples = torch.rand(
                (max_top_k_round, batch_size), device=probs.device
            )
            if sampling_info.min_ps.any():
                probs = top_k_renorm_prob(probs, sampling_info.top_ks)
                probs = top_p_renorm_prob(probs, sampling_info.top_ps)
                batch_next_token_ids, success = min_p_sampling_from_probs(
                    probs, uniform_samples, sampling_info.min_ps
                )
            else:
                batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
                    probs, uniform_samples, sampling_info.top_ks, sampling_info.top_ps
                )
        else:
            # Here we provide a slower fallback implementation.
            batch_next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
                probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
            )

        if not torch.all(success):
            logging.warning("Sampling failed, fallback to top_k=1 strategy")
            probs = probs.masked_fill(torch.isnan(probs), 0.0)
            argmax_ids = torch.argmax(probs, dim=-1)
            batch_next_token_ids = torch.where(success, batch_next_token_ids, argmax_ids)

        return batch_next_token_ids

    def forward_native():
        raise NotImplementedError("Native forward is not implemented yet.")


def top_k_top_p_min_p_sampling_from_probs_torch(
    probs: torch.Tensor,
    top_ks: torch.Tensor,
    top_ps: torch.Tensor,
    min_ps: torch.Tensor,
):
    """A top-k, top-p and min-p sampling implementation with native pytorch operations."""
    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    min_p_thresholds = probs_sort[:, 0] * min_ps
    probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
    probs_sort[
        torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1)
        >= top_ks.view(-1, 1)
    ] = 0.0
    probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
    probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
    try:
        sampled_index = torch.multinomial(probs_sort, num_samples=1)
    except RuntimeError as e:
        logger.warning(f"Sampling error: {e}")
        batch_next_token_ids = torch.zeros(
            (probs_sort.shape[0],), dtype=torch.int32, device=probs.device
        )
        success = torch.zeros(probs.shape[0], dtype=torch.bool, device=probs.device)
        return batch_next_token_ids, success
    batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
    success = torch.ones(probs.shape[0], dtype=torch.bool, device=probs.device)
    return batch_next_token_ids, success

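
For orientation (not part of the commit): below is a minimal, CPU-only sketch of the masking steps performed by the pure-PyTorch fallback above (top_k_top_p_min_p_sampling_from_probs_torch). The toy batch, vocabulary size, and parameter values are invented for illustration.

import torch

# Toy batch: 2 sequences over a 6-token vocabulary (values are illustrative only).
probs = torch.tensor(
    [
        [0.40, 0.25, 0.15, 0.10, 0.06, 0.04],
        [0.70, 0.10, 0.08, 0.06, 0.04, 0.02],
    ]
)
top_ks = torch.tensor([3, 2])      # keep at most k tokens per row
top_ps = torch.tensor([0.8, 0.9])  # nucleus threshold per row
min_ps = torch.tensor([0.0, 0.1])  # min-p, relative to the row's max probability

# Same masking steps as the fallback: sort, then zero out tokens that violate
# top-p (cumulative mass before the token exceeds top_p), top-k (rank >= k),
# and min-p (prob below min_p * max prob), then renormalize and sample.
probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
probs_sum = torch.cumsum(probs_sort, dim=-1)
min_p_thresholds = probs_sort[:, 0] * min_ps
probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
probs_sort[
    torch.arange(0, probs.shape[-1]).view(1, -1) >= top_ks.view(-1, 1)
] = 0.0
probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])

sampled_index = torch.multinomial(probs_sort, num_samples=1)
next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
print(next_token_ids)  # token ids in the original vocabulary order, e.g. tensor([1, 0])
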
python/sglang/srt/managers/io_struct.py

@@ -23,7 +23,7 @@ from dataclasses import dataclass
from typing import Dict, List, Optional, Union

from sglang.srt.managers.schedule_batch import BaseFinishReason
-from sglang.srt.sampling_params import SamplingParams
+from sglang.srt.sampling.sampling_params import SamplingParams


@dataclass

python/sglang/srt/managers/schedule_batch.py

@@ -20,22 +20,14 @@ from dataclasses import dataclass
from typing import List, Optional, Union

import torch
import torch.distributed as dist
from flashinfer.sampling import (
    min_p_sampling_from_probs,
    top_k_renorm_prob,
    top_k_top_p_sampling_from_probs,
    top_p_renorm_prob,
)
from vllm.distributed import get_tensor_model_parallel_group

import sglang.srt.sampling.penaltylib as penaltylib
from sglang.global_config import global_config
from sglang.srt.constrained import RegexGuide
from sglang.srt.constrained.jump_forward import JumpForwardMap
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo

INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

@@ -340,14 +332,6 @@ class ScheduleBatch:
    return_logprob: bool = False
    top_logprobs_nums: List[int] = None

    # Batched sampling params
    temperatures: torch.Tensor = None
    top_ps: torch.Tensor = None
    top_ks: torch.Tensor = None
    min_ps: torch.Tensor = None
    penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
    logit_bias: torch.Tensor = None

    @classmethod
    def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
        return_logprob = any(req.return_logprob for req in reqs)

@@ -395,46 +379,6 @@ class ScheduleBatch:
        return out_cache_loc

    def batch_sampling_params(self, vocab_size):
        device = "cuda"
        bs, reqs = self.batch_size(), self.reqs
        self.temperatures = torch.tensor(
            [r.sampling_params.temperature for r in reqs],
            dtype=torch.float,
            device=device,
        ).view(-1, 1)
        self.top_ps = torch.tensor(
            [r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
        )
        self.top_ks = torch.tensor(
            [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
        )
        self.min_ps = torch.tensor(
            [r.sampling_params.min_p for r in reqs], dtype=torch.float, device=device
        )

        # Each penalizers will do nothing if they evaluate themselves as not required by looking at
        # the sampling_params of the requests (See {_is_required()} of each penalizers). So this
        # should not add hefty computation overhead other than simple checks.
        #
        # While we choose not to even create the class instances if they are not required, this
        # could add additional complexity to the {ScheduleBatch} class, especially we need to
        # handle {filter_batch()} and {merge()} cases as well.
        self.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
            vocab_size=vocab_size,
            batch=self,
            device=device,
            Penalizers={
                penaltylib.BatchedFrequencyPenalizer,
                penaltylib.BatchedMinNewTokensPenalizer,
                penaltylib.BatchedPresencePenalizer,
                penaltylib.BatchedRepetitionPenalizer,
            },
        )

        # Handle logit bias but only allocate when needed
        self.logit_bias = None

    def prepare_for_extend(self, vocab_size: int):
        bs = self.batch_size()
        reqs = self.reqs

@@ -475,7 +419,7 @@ class ScheduleBatch:
        self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
        self.prefix_lens_cpu = [len(r.prefix_indices) for r in reqs]

-        self.batch_sampling_params(vocab_size)
+        self.sampling_info = SamplingBatchInfo.from_schedule_batch(self, vocab_size)

    def mix_with_running(self, running_batch: "ScheduleBatch"):
        # NOTE: prefix_indices is what has been cached, but we don't cache each decode step

@@ -684,6 +628,8 @@ class ScheduleBatch:
            self.req_pool_indices, self.seq_lens - 1
        ] = self.out_cache_loc

        self.sampling_info.update_regex_vocab_mask(self)

    def filter_batch(self, unfinished_indices: List[int]):
        if unfinished_indices is None or len(unfinished_indices) == 0:
            # Filter out all requests

@@ -704,24 +650,13 @@ class ScheduleBatch:
        self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices]
        self.return_logprob = any(req.return_logprob for req in self.reqs)

        self.penalizer_orchestrator.filter(unfinished_indices, new_indices)

        for item in [
            "temperatures",
            "top_ps",
            "top_ks",
            "min_ps",
            "logit_bias",
        ]:
            self_val = getattr(self, item, None)
            if self_val is not None:  # logit_bias can be None
                setattr(self, item, self_val[new_indices])

        self.sampling_info.filter(unfinished_indices, new_indices)

    def merge(self, other: "ScheduleBatch"):
        # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
        # orchestrator.merge() depends on Batch.reqs during preparation of each penalizers, so it
        # needs to be called with pre-merged Batch.reqs.
        self.penalizer_orchestrator.merge(other.penalizer_orchestrator)

        self.sampling_info.merge(other.sampling_info)

        self.reqs.extend(other.reqs)

@@ -736,125 +671,11 @@ class ScheduleBatch:
        self.top_logprobs_nums.extend(other.top_logprobs_nums)
        self.return_logprob = any(req.return_logprob for req in self.reqs)

        for item in [
            "temperatures",
            "top_ps",
            "top_ks",
            "min_ps",
        ]:
            self_val = getattr(self, item, None)
            other_val = getattr(other, item, None)
            setattr(self, item, torch.concat([self_val, other_val]))

        # logit_bias can be None
        if self.logit_bias is not None or other.logit_bias is not None:
            vocab_size = (
                self.logit_bias.shape[1]
                if self.logit_bias is not None
                else other.logit_bias.shape[1]
            )
            if self.logit_bias is None:
                self.logit_bias = torch.zeros(
                    (len(self.reqs), vocab_size), dtype=torch.float32, device="cuda"
                )
            if other.logit_bias is None:
                other.logit_bias = torch.zeros(
                    (len(other.reqs), vocab_size), dtype=torch.float32, device="cuda"
                )
            self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])

    def sample(self, logits: torch.Tensor):
        # TODO(lsyin): move this into a part of layer and run with CUDA Graph
        # Post process logits
        logits = logits.contiguous()
        logits.div_(self.temperatures)
        if self.logit_bias is not None:
            logits.add_(self.logit_bias)

        has_regex = any(req.regex_fsm is not None for req in self.reqs)
        if has_regex:
            allowed_mask = torch.empty_like(logits[0], dtype=torch.bool)
            for i, req in enumerate(self.reqs):
                if req.regex_fsm is not None:
                    allowed_mask.zero_()
                    allowed_mask[
                        req.regex_fsm.get_next_instruction(req.regex_fsm_state).tokens
                    ] = 1
                    logits[i].masked_fill_(~allowed_mask, float("-inf"))

        logits = self.penalizer_orchestrator.apply(logits)

        probs = torch.softmax(logits, dim=-1)

        if not global_server_args_dict["disable_flashinfer_sampling"]:
            max_top_k_round, batch_size = 32, probs.shape[0]
            uniform_samples = torch.rand(
                (max_top_k_round, batch_size), device=probs.device
            )
            if self.min_ps.any():
                probs = top_k_renorm_prob(probs, self.top_ks)
                probs = top_p_renorm_prob(probs, self.top_ps)
                batch_next_token_ids, success = min_p_sampling_from_probs(
                    probs, uniform_samples, self.min_ps
                )
            else:
                batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
                    probs, uniform_samples, self.top_ks, self.top_ps
                )
        else:
            # Here we provide a slower fallback implementation.
            batch_next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
                probs, self.top_ks, self.top_ps, self.min_ps
            )
        from sglang.srt.layers.sampler import Sampler

        if not torch.all(success):
            logger.warning(f"Sampling failed. Fallback to top_k=1 strategy. {logits=}")
            probs = probs.masked_fill(torch.isnan(probs), 0.0)
            argmax_ids = torch.argmax(probs, dim=-1)
            batch_next_token_ids = torch.where(success, batch_next_token_ids, argmax_ids)

        if has_regex:
            batch_next_token_ids_cpu = batch_next_token_ids.cpu().numpy()
            for i, req in enumerate(self.reqs):
                if req.regex_fsm is not None:
                    req.regex_fsm_state = req.regex_fsm.get_next_state(
                        req.regex_fsm_state, batch_next_token_ids_cpu[i]
                    )

        sampler = Sampler()

        self.penalizer_orchestrator.cumulate_output_tokens(batch_next_token_ids)

        batch_next_token_ids = sampler(logits, self.sampling_info)

        return batch_next_token_ids


def top_k_top_p_min_p_sampling_from_probs_torch(
    probs: torch.Tensor,
    top_ks: torch.Tensor,
    top_ps: torch.Tensor,
    min_ps: torch.Tensor,
):
    """A top-k, top-p and min-p sampling implementation with native pytorch operations."""
    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    min_p_thresholds = probs_sort[:, 0] * min_ps
    probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
    probs_sort[
        torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1)
        >= top_ks.view(-1, 1)
    ] = 0.0
    probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
    probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
    try:
        sampled_index = torch.multinomial(probs_sort, num_samples=1)
    except RuntimeError as e:
        logger.warning(f"Sampling error: {e}")
        batch_next_token_ids = torch.zeros(
            (probs_sort.shape[0],), dtype=torch.int32, device=probs.device
        )
        success = torch.zeros(probs.shape[0], dtype=torch.bool, device=probs.device)
        return batch_next_token_ids, success
    batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
    success = torch.ones(probs.shape[0], dtype=torch.bool, device=probs.device)
    return batch_next_token_ids, success

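
For orientation (not part of the commit): with this change, ScheduleBatch.sample() delegates to the new Sampler layer, which consumes a SamplingBatchInfo. Below is a minimal torch-only sketch of how the batched vocab mask constrains logits, in the spirit of SamplingBatchInfo.update_regex_vocab_mask plus the masked_fill in Sampler.forward_cuda; the allowed-token lists and sizes are made up for illustration.

import torch

# Toy setup: 2 requests over a 5-token vocabulary (sizes and token ids are made up).
bs, vocab_size = 2, 5
logits = torch.randn(bs, vocab_size)

# Allowed next-token ids per request, e.g. as a regex FSM would report them.
allowed_tokens = [[0, 3], [1, 2, 4]]

# Build a boolean vocab mask the way update_regex_vocab_mask does:
# one row per request, True only for the tokens the FSM allows next.
vocab_mask = torch.zeros(bs, vocab_size, dtype=torch.bool)
for i, tokens in enumerate(allowed_tokens):
    vocab_mask[i][tokens] = True

# Apply it the way Sampler.forward_cuda does: disallowed tokens get -inf,
# so they receive zero probability after the softmax.
logits = logits.masked_fill(~vocab_mask, float("-inf"))
probs = torch.softmax(logits, dim=-1)
print(probs)  # each row has nonzero mass only on its allowed tokens
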
python/sglang/srt/managers/tokenizer_manager.py

@@ -50,7 +50,7 @@ from sglang.srt.managers.io_struct import (
    UpdateWeightReqOutput,
)
from sglang.srt.mm_utils import expand2square, process_anyres_image
-from sglang.srt.sampling_params import SamplingParams
+from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import is_generation_model, is_multimodal_model, load_image
from sglang.utils import get_exception_traceback

python/sglang/srt/managers/tp_worker.py

@@ -482,6 +482,9 @@ class ModelTpServer:
            if batch.extend_num_tokens != 0:
                output = self.model_runner.forward(batch, ForwardMode.EXTEND)
                next_token_ids = batch.sample(output.next_token_logits)

                batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                    next_token_ids
                )

                # Move logprobs to cpu
                if output.next_token_logprobs is not None:

@@ -514,6 +517,11 @@ class ModelTpServer:
                req.output_ids.append(next_token_ids[i])
                req.check_finished()

                if req.regex_fsm is not None:
                    req.regex_fsm_state = req.regex_fsm.get_next_state(
                        req.regex_fsm_state, next_token_ids[i]
                    )

                if req.finished():
                    self.tree_cache.cache_finished_req(req)
                elif req not in decoding_reqs:

@@ -642,6 +650,9 @@ class ModelTpServer:
        # Forward and sample the next tokens
        output = self.model_runner.forward(batch, ForwardMode.DECODE)
        next_token_ids = batch.sample(output.next_token_logits)

        batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
            next_token_ids
        )

        # Move logprobs to cpu
        if output.next_token_logprobs is not None:

@@ -658,6 +669,11 @@ class ModelTpServer:
            req.output_ids.append(next_token_id)
            req.check_finished()

            if req.regex_fsm is not None:
                req.regex_fsm_state = req.regex_fsm.get_next_state(
                    req.regex_fsm_state, next_token_id
                )

            if req.finished():
                self.tree_cache.cache_finished_req(req)

python/sglang/srt/model_executor/model_runner.py

@@ -120,9 +120,6 @@ class ModelRunner:
            self.gpu_id, distributed=self.tp_size > 1
        )
        self.tp_group = get_tp_group()
        self.is_multi_node_tp = not all(
            in_the_same_node_as(self.tp_group.cpu_group, source_rank=0)
        )
        if self.tp_size > 1:
            total_local_gpu_memory = get_available_gpu_memory(self.gpu_id)

python/sglang/srt/sampling/sampling_batch_info.py (new file, mode 0 → 100644)

from __future__ import annotations

import dataclasses
from typing import TYPE_CHECKING, List

import torch

import sglang.srt.sampling.penaltylib as penaltylib

if TYPE_CHECKING:
    from sglang.srt.managers.schedule_batch import ScheduleBatch


@dataclasses.dataclass
class SamplingBatchInfo:
    # Basic Info
    vocab_size: int

    # Batched sampling params
    temperatures: torch.Tensor = None
    top_ps: torch.Tensor = None
    top_ks: torch.Tensor = None
    min_ps: torch.Tensor = None
    penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
    logit_bias: torch.Tensor = None
    vocab_mask: torch.Tensor = None

    @classmethod
    def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
        device = "cuda"
        reqs = batch.reqs
        ret = cls(vocab_size=vocab_size)

        ret.temperatures = torch.tensor(
            [r.sampling_params.temperature for r in reqs],
            dtype=torch.float,
            device=device,
        ).view(-1, 1)
        ret.top_ps = torch.tensor(
            [r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
        )
        ret.top_ks = torch.tensor(
            [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
        )
        ret.min_ps = torch.tensor(
            [r.sampling_params.min_p for r in reqs], dtype=torch.float, device=device
        )

        # Each penalizers will do nothing if they evaluate themselves as not required by looking at
        # the sampling_params of the requests (See {_is_required()} of each penalizers). So this
        # should not add hefty computation overhead other than simple checks.
        #
        # While we choose not to even create the class instances if they are not required, this
        # could add additional complexity to the {ScheduleBatch} class, especially we need to
        # handle {filter_batch()} and {merge()} cases as well.
        ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
            vocab_size=vocab_size,
            batch=batch,
            device=device,
            Penalizers={
                penaltylib.BatchedFrequencyPenalizer,
                penaltylib.BatchedMinNewTokensPenalizer,
                penaltylib.BatchedPresencePenalizer,
                penaltylib.BatchedRepetitionPenalizer,
            },
        )

        # Handle logit bias but only allocate when needed
        ret.logit_bias = None

        ret.update_regex_vocab_mask(batch)

        return ret

    def update_regex_vocab_mask(self, batch: ScheduleBatch):
        bs, reqs = batch.batch_size(), batch.reqs
        device = "cuda"
        has_regex = any(req.regex_fsm is not None for req in reqs)

        # Reset the vocab mask
        self.vocab_mask = None

        if has_regex:
            for i, req in enumerate(reqs):
                if req.regex_fsm is not None:
                    if self.vocab_mask is None:
                        self.vocab_mask = torch.zeros(
                            bs, self.vocab_size, dtype=torch.bool, device=device
                        )
                    self.vocab_mask[i][
                        req.regex_fsm.get_next_instruction(req.regex_fsm_state).tokens
                    ] = 1

    def filter(self, unfinished_indices: List[int], new_indices: torch.Tensor):
        self.penalizer_orchestrator.filter(unfinished_indices, new_indices)

        for item in [
            "temperatures",
            "top_ps",
            "top_ks",
            "min_ps",
            "logit_bias",
        ]:
            self_val = getattr(self, item, None)
            if self_val is not None:  # logit_bias can be None
                setattr(self, item, self_val[new_indices])

    def merge(self, other: "SamplingBatchInfo"):
        self.penalizer_orchestrator.merge(other.penalizer_orchestrator)

        for item in [
            "temperatures",
            "top_ps",
            "top_ks",
            "min_ps",
        ]:
            self_val = getattr(self, item, None)
            other_val = getattr(other, item, None)
            setattr(self, item, torch.concat([self_val, other_val]))

        # logit_bias can be None
        if self.logit_bias is not None or other.logit_bias is not None:
            vocab_size = (
                self.logit_bias.shape[1]
                if self.logit_bias is not None
                else other.logit_bias.shape[1]
            )
            if self.logit_bias is None:
                self.logit_bias = torch.zeros(
                    (len(self.reqs), vocab_size), dtype=torch.float32, device="cuda"
                )
            if other.logit_bias is None:
                other.logit_bias = torch.zeros(
                    (len(other.reqs), vocab_size), dtype=torch.float32, device="cuda"
                )
            self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])

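
For orientation (not part of the commit): a minimal CPU-only sketch of the batch/filter/merge pattern that SamplingBatchInfo implements for per-request sampling parameters. The SimpleNamespace stand-ins and all values are hypothetical; the real code builds CUDA tensors from sglang's SamplingParams objects.

from types import SimpleNamespace

import torch

# Stand-in per-request sampling params (hypothetical values).
reqs = [
    SimpleNamespace(temperature=0.7, top_p=0.9, top_k=40, min_p=0.0),
    SimpleNamespace(temperature=1.0, top_p=0.8, top_k=20, min_p=0.05),
]

# Batch the per-request scalars into tensors, mirroring from_schedule_batch:
# temperatures are shaped (bs, 1) so they broadcast over the vocab dimension.
temperatures = torch.tensor([r.temperature for r in reqs], dtype=torch.float).view(-1, 1)
top_ps = torch.tensor([r.top_p for r in reqs], dtype=torch.float)
top_ks = torch.tensor([r.top_k for r in reqs], dtype=torch.int)
min_ps = torch.tensor([r.min_p for r in reqs], dtype=torch.float)

# filter(): keep only the still-unfinished rows by indexing with new_indices.
new_indices = torch.tensor([1])
temperatures, top_ps = temperatures[new_indices], top_ps[new_indices]

# merge(): concatenate two batches row-wise, as merge() does with torch.concat.
other_temperatures = torch.tensor([[0.5]])
temperatures = torch.concat([temperatures, other_temperatures])
print(temperatures.shape)  # torch.Size([2, 1])
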
python/sglang/srt/sampling_params.py → python/sglang/srt/sampling/sampling_params.py

File moved