sglang · Commits · 83e23c69

Commit 83e23c69 (unverified)
Improve code style of sampler (#1168)

Authored Aug 21, 2024 by Liangsheng Yin; committed by GitHub on Aug 21, 2024.
Parent: ac1b74fa

Showing 10 changed files with 269 additions and 195 deletions (+269, -195).
Files changed:

  examples/usage/json_decode.py                       +3    -0
  python/sglang/bench_latency.py                      +1    -1
  python/sglang/srt/layers/sampler.py                 +101  -0
  python/sglang/srt/managers/io_struct.py             +1    -1
  python/sglang/srt/managers/schedule_batch.py        +10   -189
  python/sglang/srt/managers/tokenizer_manager.py     +1    -1
  python/sglang/srt/managers/tp_worker.py             +16   -0
  python/sglang/srt/model_executor/model_runner.py    +0    -3
  python/sglang/srt/sampling/sampling_batch_info.py   +136  -0
  python/sglang/srt/sampling/sampling_params.py       +0    -0
examples/usage/json_decode.py

@@ -35,6 +35,9 @@ def character_gen(s, name):
         name
         + " is a character in Harry Potter. Please fill in the following information about this character.\n"
     )
+    s += "The constrained regex is:\n"
+    s += character_regex + "\n"
+    s += "The JSON output is:\n"
     s += sgl.gen("json_output", max_tokens=256, regex=character_regex)
python/sglang/bench_latency.py

@@ -54,7 +54,7 @@ from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
 from sglang.srt.model_config import ModelConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.model_executor.model_runner import ModelRunner
-from sglang.srt.sampling_params import SamplingParams
+from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import suppress_other_loggers
python/sglang/srt/layers/sampler.py (new file, 0 → 100644)

import logging

import torch
from flashinfer.sampling import (
    min_p_sampling_from_probs,
    top_k_renorm_prob,
    top_k_top_p_sampling_from_probs,
    top_p_renorm_prob,
)
from vllm.model_executor.custom_op import CustomOp

# TODO: move this dict to another place
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo

logger = logging.getLogger(__name__)


class Sampler(CustomOp):
    def __init__(self):
        super().__init__()

    def forward_cuda(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
        # Post process logits
        logits = logits.contiguous()
        logits.div_(sampling_info.temperatures)
        if sampling_info.logit_bias is not None:
            logits.add_(sampling_info.logit_bias)

        if sampling_info.vocab_mask is not None:
            logits = logits.masked_fill(~sampling_info.vocab_mask, float("-inf"))

        logits = sampling_info.penalizer_orchestrator.apply(logits)

        probs = torch.softmax(logits, dim=-1)

        if not global_server_args_dict["disable_flashinfer_sampling"]:
            max_top_k_round, batch_size = 32, probs.shape[0]
            uniform_samples = torch.rand(
                (max_top_k_round, batch_size), device=probs.device
            )
            if sampling_info.min_ps.any():
                probs = top_k_renorm_prob(probs, sampling_info.top_ks)
                probs = top_p_renorm_prob(probs, sampling_info.top_ps)
                batch_next_token_ids, success = min_p_sampling_from_probs(
                    probs, uniform_samples, sampling_info.min_ps
                )
            else:
                batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
                    probs, uniform_samples, sampling_info.top_ks, sampling_info.top_ps
                )
        else:
            # Here we provide a slower fallback implementation.
            batch_next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
                probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
            )

        if not torch.all(success):
            logging.warning("Sampling failed, fallback to top_k=1 strategy")
            probs = probs.masked_fill(torch.isnan(probs), 0.0)
            argmax_ids = torch.argmax(probs, dim=-1)
            batch_next_token_ids = torch.where(success, batch_next_token_ids, argmax_ids)

        return batch_next_token_ids

    def forward_native():
        raise NotImplementedError("Native forward is not implemented yet.")


def top_k_top_p_min_p_sampling_from_probs_torch(
    probs: torch.Tensor,
    top_ks: torch.Tensor,
    top_ps: torch.Tensor,
    min_ps: torch.Tensor,
):
    """A top-k, top-p and min-p sampling implementation with native pytorch operations."""
    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    min_p_thresholds = probs_sort[:, 0] * min_ps
    probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
    probs_sort[
        torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1)
        >= top_ks.view(-1, 1)
    ] = 0.0
    probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
    probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
    try:
        sampled_index = torch.multinomial(probs_sort, num_samples=1)
    except RuntimeError as e:
        logger.warning(f"Sampling error: {e}")
        batch_next_token_ids = torch.zeros(
            (probs_sort.shape[0],), dtype=torch.int32, device=probs.device
        )
        success = torch.zeros(probs.shape[0], dtype=torch.bool, device=probs.device)
        return batch_next_token_ids, success

    batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
    success = torch.ones(probs.shape[0], dtype=torch.bool, device=probs.device)
    return batch_next_token_ids, success
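
Not part of the commit: a minimal usage sketch that exercises the pure-PyTorch fallback defined above. It assumes sglang at this revision (plus the flashinfer and vllm packages the module imports) is installed; the batch tensors below are fabricated for illustration only.

import torch

# The fallback sampler added in this new file.
from sglang.srt.layers.sampler import top_k_top_p_min_p_sampling_from_probs_torch

torch.manual_seed(0)
batch_size, vocab_size = 4, 32

# Fabricated per-request probabilities and sampling parameters (not real server state).
probs = torch.softmax(torch.randn(batch_size, vocab_size), dim=-1)
top_ks = torch.tensor([1, 5, 10, vocab_size], dtype=torch.int)
top_ps = torch.tensor([1.0, 0.9, 0.8, 0.95])
min_ps = torch.tensor([0.0, 0.0, 0.05, 0.1])

next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
    probs, top_ks, top_ps, min_ps
)
print(next_token_ids.tolist(), success.tolist())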
python/sglang/srt/managers/io_struct.py

@@ -23,7 +23,7 @@ from dataclasses import dataclass
 from typing import Dict, List, Optional, Union
 
 from sglang.srt.managers.schedule_batch import BaseFinishReason
-from sglang.srt.sampling_params import SamplingParams
+from sglang.srt.sampling.sampling_params import SamplingParams
 
 
 @dataclass
python/sglang/srt/managers/schedule_batch.py

@@ -20,22 +20,14 @@ from dataclasses import dataclass
 from typing import List, Optional, Union
 
 import torch
-import torch.distributed as dist
-from flashinfer.sampling import (
-    min_p_sampling_from_probs,
-    top_k_renorm_prob,
-    top_k_top_p_sampling_from_probs,
-    top_p_renorm_prob,
-)
-from vllm.distributed import get_tensor_model_parallel_group
 
 import sglang.srt.sampling.penaltylib as penaltylib
 from sglang.global_config import global_config
 from sglang.srt.constrained import RegexGuide
 from sglang.srt.constrained.jump_forward import JumpForwardMap
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
+from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 
 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

@@ -340,14 +332,6 @@ class ScheduleBatch:
     return_logprob: bool = False
     top_logprobs_nums: List[int] = None
 
-    # Batched sampling params
-    temperatures: torch.Tensor = None
-    top_ps: torch.Tensor = None
-    top_ks: torch.Tensor = None
-    min_ps: torch.Tensor = None
-    penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
-    logit_bias: torch.Tensor = None
-
     @classmethod
     def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
         return_logprob = any(req.return_logprob for req in reqs)

@@ -395,46 +379,6 @@ class ScheduleBatch:
 
         return out_cache_loc
 
-    def batch_sampling_params(self, vocab_size):
-        device = "cuda"
-        bs, reqs = self.batch_size(), self.reqs
-        self.temperatures = torch.tensor(
-            [r.sampling_params.temperature for r in reqs],
-            dtype=torch.float,
-            device=device,
-        ).view(-1, 1)
-        self.top_ps = torch.tensor(
-            [r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
-        )
-        self.top_ks = torch.tensor(
-            [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
-        )
-        self.min_ps = torch.tensor(
-            [r.sampling_params.min_p for r in reqs], dtype=torch.float, device=device
-        )
-
-        # Each penalizers will do nothing if they evaluate themselves as not required by looking at
-        # the sampling_params of the requests (See {_is_required()} of each penalizers). So this
-        # should not add hefty computation overhead other than simple checks.
-        #
-        # While we choose not to even create the class instances if they are not required, this
-        # could add additional complexity to the {ScheduleBatch} class, especially we need to
-        # handle {filter_batch()} and {merge()} cases as well.
-        self.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
-            vocab_size=vocab_size,
-            batch=self,
-            device=device,
-            Penalizers={
-                penaltylib.BatchedFrequencyPenalizer,
-                penaltylib.BatchedMinNewTokensPenalizer,
-                penaltylib.BatchedPresencePenalizer,
-                penaltylib.BatchedRepetitionPenalizer,
-            },
-        )
-
-        # Handle logit bias but only allocate when needed
-        self.logit_bias = None
-
     def prepare_for_extend(self, vocab_size: int):
         bs = self.batch_size()
         reqs = self.reqs

@@ -475,7 +419,7 @@ class ScheduleBatch:
         self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
         self.prefix_lens_cpu = [len(r.prefix_indices) for r in reqs]
 
-        self.batch_sampling_params(vocab_size)
+        self.sampling_info = SamplingBatchInfo.from_schedule_batch(self, vocab_size)
 
     def mix_with_running(self, running_batch: "ScheduleBatch"):
         # NOTE: prefix_indices is what has been cached, but we don't cache each decode step

@@ -684,6 +628,8 @@ class ScheduleBatch:
             self.req_pool_indices, self.seq_lens - 1
         ] = self.out_cache_loc
 
+        self.sampling_info.update_regex_vocab_mask(self)
+
     def filter_batch(self, unfinished_indices: List[int]):
         if unfinished_indices is None or len(unfinished_indices) == 0:
             # Filter out all requests

@@ -704,24 +650,13 @@ class ScheduleBatch:
         self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices]
         self.return_logprob = any(req.return_logprob for req in self.reqs)
 
-        self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
-
-        for item in [
-            "temperatures",
-            "top_ps",
-            "top_ks",
-            "min_ps",
-            "logit_bias",
-        ]:
-            self_val = getattr(self, item, None)
-            if self_val is not None:  # logit_bias can be None
-                setattr(self, item, self_val[new_indices])
+        self.sampling_info.filter(unfinished_indices, new_indices)
 
     def merge(self, other: "ScheduleBatch"):
         # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
         # orchestrator.merge() depends on Batch.reqs during preparation of each penalizers, so it
         # needs to be called with pre-merged Batch.reqs.
-        self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
+        self.sampling_info.merge(other.sampling_info)
 
         self.reqs.extend(other.reqs)

@@ -736,125 +671,11 @@ class ScheduleBatch:
         self.top_logprobs_nums.extend(other.top_logprobs_nums)
         self.return_logprob = any(req.return_logprob for req in self.reqs)
 
-        for item in [
-            "temperatures",
-            "top_ps",
-            "top_ks",
-            "min_ps",
-        ]:
-            self_val = getattr(self, item, None)
-            other_val = getattr(other, item, None)
-            setattr(self, item, torch.concat([self_val, other_val]))
-
-        # logit_bias can be None
-        if self.logit_bias is not None or other.logit_bias is not None:
-            vocab_size = (
-                self.logit_bias.shape[1]
-                if self.logit_bias is not None
-                else other.logit_bias.shape[1]
-            )
-            if self.logit_bias is None:
-                self.logit_bias = torch.zeros(
-                    (len(self.reqs), vocab_size), dtype=torch.float32, device="cuda"
-                )
-            if other.logit_bias is None:
-                other.logit_bias = torch.zeros(
-                    (len(other.reqs), vocab_size), dtype=torch.float32, device="cuda"
-                )
-            self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])
-
     def sample(self, logits: torch.Tensor):
-        # TODO(lsyin): move this into a part of layer and run with CUDA Graph
-        # Post process logits
-        logits = logits.contiguous()
-        logits.div_(self.temperatures)
-        if self.logit_bias is not None:
-            logits.add_(self.logit_bias)
-
-        has_regex = any(req.regex_fsm is not None for req in self.reqs)
-        if has_regex:
-            allowed_mask = torch.empty_like(logits[0], dtype=torch.bool)
-            for i, req in enumerate(self.reqs):
-                if req.regex_fsm is not None:
-                    allowed_mask.zero_()
-                    allowed_mask[
-                        req.regex_fsm.get_next_instruction(req.regex_fsm_state).tokens
-                    ] = 1
-                    logits[i].masked_fill_(~allowed_mask, float("-inf"))
-
-        logits = self.penalizer_orchestrator.apply(logits)
-
-        probs = torch.softmax(logits, dim=-1)
-
-        if not global_server_args_dict["disable_flashinfer_sampling"]:
-            max_top_k_round, batch_size = 32, probs.shape[0]
-            uniform_samples = torch.rand(
-                (max_top_k_round, batch_size), device=probs.device
-            )
-            if self.min_ps.any():
-                probs = top_k_renorm_prob(probs, self.top_ks)
-                probs = top_p_renorm_prob(probs, self.top_ps)
-                batch_next_token_ids, success = min_p_sampling_from_probs(
-                    probs, uniform_samples, self.min_ps
-                )
-            else:
-                batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
-                    probs, uniform_samples, self.top_ks, self.top_ps
-                )
-        else:
-            # Here we provide a slower fallback implementation.
-            batch_next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
-                probs, self.top_ks, self.top_ps, self.min_ps
-            )
+        from sglang.srt.layers.sampler import Sampler
 
-        if not torch.all(success):
-            logger.warning(f"Sampling failed. Fallback to top_k=1 strategy. {logits=}")
-            probs = probs.masked_fill(torch.isnan(probs), 0.0)
-            argmax_ids = torch.argmax(probs, dim=-1)
-            batch_next_token_ids = torch.where(success, batch_next_token_ids, argmax_ids)
-
-        if has_regex:
-            batch_next_token_ids_cpu = batch_next_token_ids.cpu().numpy()
-            for i, req in enumerate(self.reqs):
-                if req.regex_fsm is not None:
-                    req.regex_fsm_state = req.regex_fsm.get_next_state(
-                        req.regex_fsm_state, batch_next_token_ids_cpu[i]
-                    )
+        sampler = Sampler()
 
-        self.penalizer_orchestrator.cumulate_output_tokens(batch_next_token_ids)
+        batch_next_token_ids = sampler(logits, self.sampling_info)
 
         return batch_next_token_ids
-
-
-def top_k_top_p_min_p_sampling_from_probs_torch(
-    probs: torch.Tensor,
-    top_ks: torch.Tensor,
-    top_ps: torch.Tensor,
-    min_ps: torch.Tensor,
-):
-    """A top-k, top-p and min-p sampling implementation with native pytorch operations."""
-    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
-    probs_sum = torch.cumsum(probs_sort, dim=-1)
-    min_p_thresholds = probs_sort[:, 0] * min_ps
-    probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
-    probs_sort[
-        torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1)
-        >= top_ks.view(-1, 1)
-    ] = 0.0
-    probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
-    probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
-    try:
-        sampled_index = torch.multinomial(probs_sort, num_samples=1)
-    except RuntimeError as e:
-        logger.warning(f"Sampling error: {e}")
-        batch_next_token_ids = torch.zeros(
-            (probs_sort.shape[0],), dtype=torch.int32, device=probs.device
-        )
-        success = torch.zeros(probs.shape[0], dtype=torch.bool, device=probs.device)
-        return batch_next_token_ids, success
-
-    batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
-    success = torch.ones(probs.shape[0], dtype=torch.bool, device=probs.device)
-    return batch_next_token_ids, success
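
Not part of the diff: a small, self-contained PyTorch sketch of the "fall back to greedy" safety net that both the removed ScheduleBatch.sample and the new Sampler.forward_cuda apply when a sampling kernel reports failure. The success flags and probabilities below are made up for illustration.

import torch

# Toy batch: pretend the sampling kernel failed for request 1 and left NaNs there.
probs = torch.softmax(torch.randn(3, 8), dim=-1)
probs[1] = float("nan")
sampled_ids = torch.tensor([2, 0, 5])
success = torch.tensor([True, False, True])

# Same recovery pattern as in the commit: zero out NaNs, take the argmax,
# and keep the sampled id only where the kernel reported success.
probs = probs.masked_fill(torch.isnan(probs), 0.0)
argmax_ids = torch.argmax(probs, dim=-1)
next_token_ids = torch.where(success, sampled_ids, argmax_ids)
print(next_token_ids.tolist())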
python/sglang/srt/managers/tokenizer_manager.py

@@ -50,7 +50,7 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightReqOutput,
 )
 from sglang.srt.mm_utils import expand2square, process_anyres_image
-from sglang.srt.sampling_params import SamplingParams
+from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import is_generation_model, is_multimodal_model, load_image
 from sglang.utils import get_exception_traceback
python/sglang/srt/managers/tp_worker.py

@@ -482,6 +482,9 @@ class ModelTpServer:
         if batch.extend_num_tokens != 0:
             output = self.model_runner.forward(batch, ForwardMode.EXTEND)
             next_token_ids = batch.sample(output.next_token_logits)
+            batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
+                next_token_ids
+            )
 
             # Move logprobs to cpu
             if output.next_token_logprobs is not None:

@@ -514,6 +517,11 @@ class ModelTpServer:
             req.output_ids.append(next_token_ids[i])
             req.check_finished()
 
+            if req.regex_fsm is not None:
+                req.regex_fsm_state = req.regex_fsm.get_next_state(
+                    req.regex_fsm_state, next_token_ids[i]
+                )
+
             if req.finished():
                 self.tree_cache.cache_finished_req(req)
             elif req not in decoding_reqs:

@@ -642,6 +650,9 @@ class ModelTpServer:
         # Forward and sample the next tokens
         output = self.model_runner.forward(batch, ForwardMode.DECODE)
         next_token_ids = batch.sample(output.next_token_logits)
+        batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
+            next_token_ids
+        )
 
         # Move logprobs to cpu
         if output.next_token_logprobs is not None:

@@ -658,6 +669,11 @@ class ModelTpServer:
             req.output_ids.append(next_token_id)
             req.check_finished()
 
+            if req.regex_fsm is not None:
+                req.regex_fsm_state = req.regex_fsm.get_next_state(
+                    req.regex_fsm_state, next_token_id
+                )
+
             if req.finished():
                 self.tree_cache.cache_finished_req(req)
python/sglang/srt/model_executor/model_runner.py

@@ -120,9 +120,6 @@ class ModelRunner:
             self.gpu_id,
             distributed=self.tp_size > 1,
         )
         self.tp_group = get_tp_group()
-        self.is_multi_node_tp = not all(
-            in_the_same_node_as(self.tp_group.cpu_group, source_rank=0)
-        )
         if self.tp_size > 1:
             total_local_gpu_memory = get_available_gpu_memory(self.gpu_id)
python/sglang/srt/sampling/sampling_batch_info.py (new file, 0 → 100644)

from __future__ import annotations

import dataclasses
from typing import TYPE_CHECKING, List

import torch

import sglang.srt.sampling.penaltylib as penaltylib

if TYPE_CHECKING:
    from sglang.srt.managers.schedule_batch import ScheduleBatch


@dataclasses.dataclass
class SamplingBatchInfo:
    # Basic Info
    vocab_size: int

    # Batched sampling params
    temperatures: torch.Tensor = None
    top_ps: torch.Tensor = None
    top_ks: torch.Tensor = None
    min_ps: torch.Tensor = None
    penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
    logit_bias: torch.Tensor = None
    vocab_mask: torch.Tensor = None

    @classmethod
    def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
        device = "cuda"
        reqs = batch.reqs
        ret = cls(vocab_size=vocab_size)

        ret.temperatures = torch.tensor(
            [r.sampling_params.temperature for r in reqs],
            dtype=torch.float,
            device=device,
        ).view(-1, 1)
        ret.top_ps = torch.tensor(
            [r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
        )
        ret.top_ks = torch.tensor(
            [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
        )
        ret.min_ps = torch.tensor(
            [r.sampling_params.min_p for r in reqs], dtype=torch.float, device=device
        )

        # Each penalizers will do nothing if they evaluate themselves as not required by looking at
        # the sampling_params of the requests (See {_is_required()} of each penalizers). So this
        # should not add hefty computation overhead other than simple checks.
        #
        # While we choose not to even create the class instances if they are not required, this
        # could add additional complexity to the {ScheduleBatch} class, especially we need to
        # handle {filter_batch()} and {merge()} cases as well.
        ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
            vocab_size=vocab_size,
            batch=batch,
            device=device,
            Penalizers={
                penaltylib.BatchedFrequencyPenalizer,
                penaltylib.BatchedMinNewTokensPenalizer,
                penaltylib.BatchedPresencePenalizer,
                penaltylib.BatchedRepetitionPenalizer,
            },
        )

        # Handle logit bias but only allocate when needed
        ret.logit_bias = None

        ret.update_regex_vocab_mask(batch)

        return ret

    def update_regex_vocab_mask(self, batch: ScheduleBatch):
        bs, reqs = batch.batch_size(), batch.reqs
        device = "cuda"
        has_regex = any(req.regex_fsm is not None for req in reqs)

        # Reset the vocab mask
        self.vocab_mask = None

        if has_regex:
            for i, req in enumerate(reqs):
                if req.regex_fsm is not None:
                    if self.vocab_mask is None:
                        self.vocab_mask = torch.zeros(
                            bs, self.vocab_size, dtype=torch.bool, device=device
                        )
                    self.vocab_mask[i][
                        req.regex_fsm.get_next_instruction(req.regex_fsm_state).tokens
                    ] = 1

    def filter(self, unfinished_indices: List[int], new_indices: torch.Tensor):
        self.penalizer_orchestrator.filter(unfinished_indices, new_indices)

        for item in [
            "temperatures",
            "top_ps",
            "top_ks",
            "min_ps",
            "logit_bias",
        ]:
            self_val = getattr(self, item, None)
            if self_val is not None:  # logit_bias can be None
                setattr(self, item, self_val[new_indices])

    def merge(self, other: "SamplingBatchInfo"):
        self.penalizer_orchestrator.merge(other.penalizer_orchestrator)

        for item in [
            "temperatures",
            "top_ps",
            "top_ks",
            "min_ps",
        ]:
            self_val = getattr(self, item, None)
            other_val = getattr(other, item, None)
            setattr(self, item, torch.concat([self_val, other_val]))

        # logit_bias can be None
        if self.logit_bias is not None or other.logit_bias is not None:
            vocab_size = (
                self.logit_bias.shape[1]
                if self.logit_bias is not None
                else other.logit_bias.shape[1]
            )
            if self.logit_bias is None:
                self.logit_bias = torch.zeros(
                    (len(self.reqs), vocab_size), dtype=torch.float32, device="cuda"
                )
            if other.logit_bias is None:
                other.logit_bias = torch.zeros(
                    (len(other.reqs), vocab_size), dtype=torch.float32, device="cuda"
                )
            self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])
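
Not part of the diff: a stripped-down, torch-only illustration of the filter/merge bookkeeping that SamplingBatchInfo performs on its batched tensors (the real class also carries the penalizer orchestrator, logit_bias, and vocab_mask). The class and values below are invented for the example.

import dataclasses

import torch


@dataclasses.dataclass
class ToySamplingInfo:
    # Hypothetical stand-in for SamplingBatchInfo, holding only two batched params.
    temperatures: torch.Tensor
    top_ps: torch.Tensor

    def filter(self, new_indices: torch.Tensor):
        # Keep only the rows of requests that are still running.
        self.temperatures = self.temperatures[new_indices]
        self.top_ps = self.top_ps[new_indices]

    def merge(self, other: "ToySamplingInfo"):
        # Concatenate along the batch dimension when two batches are fused.
        self.temperatures = torch.concat([self.temperatures, other.temperatures])
        self.top_ps = torch.concat([self.top_ps, other.top_ps])


a = ToySamplingInfo(torch.tensor([[0.7], [1.0], [0.0]]), torch.tensor([0.9, 0.8, 1.0]))
a.filter(torch.tensor([0, 2]))  # requests 0 and 2 are unfinished
b = ToySamplingInfo(torch.tensor([[1.2]]), torch.tensor([0.95]))
a.merge(b)
print(a.temperatures.squeeze(-1).tolist(), a.top_ps.tolist())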
python/sglang/srt/sampling_params.py → python/sglang/srt/sampling/sampling_params.py

File moved without content changes.