sglang / Commits / 5e3f7e7f

Unverified commit 5e3f7e7f, authored Oct 13, 2025 by Lianmin Zheng, committed via GitHub on Oct 13, 2025.

Minor: improve sampler & remove unused fields from model_config.py (#11531)

Parent: 728af887

Showing 5 changed files with 23 additions and 9 deletions (+23 −9):
- .github/workflows/release-docker-dev.yml (+3 −1)
- python/sglang/srt/configs/model_config.py (+4 −4)
- python/sglang/srt/layers/sampler.py (+14 −1)
- python/sglang/srt/model_executor/model_runner.py (+2 −1)
- python/sglang/srt/sampling/sampling_params.py (+0 −2)
.github/workflows/release-docker-dev.yml

```diff
@@ -65,6 +65,7 @@ jobs:
             arm64_tag: dev-arm64
     steps:
       - uses: docker/setup-buildx-action@v3
       - uses: docker/login-action@v2
         with:
           username: ${{ secrets.DOCKERHUB_USERNAME }}
@@ -72,9 +73,10 @@ jobs:
       - run: |
           docker buildx imagetools create \
             -t lmsysorg/sglang:${{ matrix.variant.tag }} \
             -t lmsysorg/sglang:nightly-${{ matrix.variant.tag }}-${{ github.sha }} \
             -t lmsysorg/sglang:nightly-${{ matrix.variant.tag }}-$(date +%Y%m%d)-${{ github.sha:0:8 }} \
             lmsysorg/sglang:${{ matrix.variant.x86_tag }} \
             lmsysorg/sglang:${{ matrix.variant.arm64_tag }}
       - name: Cleanup Old Nightly Builds
         run: |
           # Get JWT token for Docker Hub API
```
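For reference, the nightly tag naming that the `imagetools create` invocation above composes can be sketched in Python. This is only an illustration of the format `nightly-<variant>-<YYYYMMDD>-<short SHA>`; the variant name and SHA below are placeholders, not values from this commit.

```python
# Illustrative sketch of the nightly tag format; not part of the workflow itself.
from datetime import datetime, timezone


def nightly_tag(variant_tag: str, sha: str) -> str:
    # Date in UTC, plus the first 8 characters of the commit SHA.
    date = datetime.now(timezone.utc).strftime("%Y%m%d")
    return f"nightly-{variant_tag}-{date}-{sha[:8]}"


print(nightly_tag("dev", "0123456789abcdef"))  # e.g. nightly-dev-20251013-01234567
```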
python/sglang/srt/configs/model_config.py

```diff
@@ -25,7 +25,7 @@ from transformers import PretrainedConfig
 from sglang.srt.environ import envs
 from sglang.srt.layers.quantization import QUANTIZATION_METHODS
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import is_hip, retry
+from sglang.srt.utils import is_hip
 from sglang.srt.utils.hf_transformers_utils import (
     get_config,
     get_context_length,
```
```diff
@@ -86,11 +86,11 @@ class ModelConfig:
         dtype: str = "auto",
         quantization: Optional[str] = None,
         modelopt_quant: Optional[Union[str, Dict]] = None,
         modelopt_checkpoint_restore_path: Optional[str] = None,
         modelopt_checkpoint_save_path: Optional[str] = None,
         override_config_file: Optional[str] = None,
         is_draft_model: bool = False,
         hybrid_kvcache_ratio: Optional[float] = None,
         # TODO: remove this, it is not a model config
         model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
         sampling_defaults: str = "openai",
     ) -> None:
```
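As a usage sketch only, here is how a `ModelConfig` might be constructed with a few of the keyword arguments from the signature above. The `model_path` argument and the model name are assumptions for illustration; neither is visible in this hunk.

```python
# Hypothetical usage sketch; model_path and the model name are assumptions,
# not taken from this diff.
from sglang.srt.configs.model_config import ModelConfig

config = ModelConfig(
    model_path="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
    dtype="auto",
    quantization=None,
    is_draft_model=False,
    sampling_defaults="openai",
)
```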
python/sglang/srt/layers/sampler.py

```diff
@@ -92,6 +92,12 @@ class Sampler(nn.Module):
             if return_logprob:
                 logprobs = torch.nn.functional.log_softmax(logits, dim=-1)
         else:
+            can_sample_directly_from_probs = (
+                not sampling_info.need_top_p_sampling
+                and not sampling_info.need_top_k_sampling
+                and not sampling_info.need_min_p_sampling
+            )
             # If requested, cache probabilities from original logits before temperature scaling.
             if return_logprob and RETURN_ORIGINAL_LOGPROB:
                 probs_without_temp_scaling = torch.softmax(logits, dim=-1)
```
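To make the fast path concrete: when none of top-k, top-p, or min-p filtering is needed, sampling "directly from probs" is just a per-row multinomial draw. The sketch below is a simplified stand-in, not sglang's `sampling_from_probs_torch`, which additionally handles per-request seeds and positions.

```python
# Minimal sketch of direct sampling from a probability distribution.
import torch


def sample_from_probs(probs: torch.Tensor) -> torch.Tensor:
    """probs: [batch_size, vocab_size], rows already sum to 1. Returns [batch_size] token ids."""
    # One multinomial draw per row; this is the whole fast path when no filtering applies.
    return torch.multinomial(probs, num_samples=1).squeeze(-1)


probs = torch.softmax(torch.randn(2, 8), dim=-1)
print(sample_from_probs(probs).shape)  # torch.Size([2])
```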
```diff
@@ -102,7 +108,14 @@ class Sampler(nn.Module):
             probs = logits
             del logits
             if True:
                 # Keep this redundant check to simplify some internal code sync
                 if can_sample_directly_from_probs:
                     # when we don't need top-k, top-p, or min-p sampling, we can directly sample from the probs
                     batch_next_token_ids = sampling_from_probs_torch(
                         probs,
                         sampling_seed=sampling_info.sampling_seed,
                         positions=positions,
                     )
                 else:
                     if get_global_server_args().sampling_backend == "flashinfer":
                         if sampling_info.need_min_p_sampling:
                             probs = top_k_renorm_prob(probs, sampling_info.top_ks)
```
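For contrast, the slower branch filters the distribution before sampling; the flashinfer backend does this with fused GPU kernels such as `top_k_renorm_prob`. The following is a plain PyTorch sketch of top-k plus top-p filtering, not sglang's implementation, and it uses a single scalar `top_k`/`top_p` where the real code carries per-request tensors.

```python
# Plain PyTorch sketch of the filtered (non-fast-path) sampling branch.
import torch


def top_k_top_p_sample(probs: torch.Tensor, top_k: int, top_p: float) -> torch.Tensor:
    sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
    # Drop everything beyond the top-k entries.
    sorted_probs[:, top_k:] = 0.0
    # Drop the tail once the cumulative mass before a token already exceeds top_p
    # (this always keeps at least the highest-probability token).
    cumsum = torch.cumsum(sorted_probs, dim=-1)
    sorted_probs[cumsum - sorted_probs > top_p] = 0.0
    # Renormalize, sample in the sorted space, then map back to vocabulary ids.
    sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
    choice = torch.multinomial(sorted_probs, num_samples=1)
    return sorted_idx.gather(-1, choice).squeeze(-1)


probs = torch.softmax(torch.randn(2, 16), dim=-1)
print(top_k_top_p_sample(probs, top_k=5, top_p=0.9))
```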
python/sglang/srt/model_executor/model_runner.py

```diff
@@ -648,7 +648,8 @@ class ModelRunner:
                 // (self.tp_size // self.moe_ep_size)
             ) % weight_block_size_n != 0:
                 raise ValueError(
-                    f"For qwen3-vl-fp8 models, please make sure ({text_config.moe_intermediate_size=} // ({self.tp_size=} // {self.moe_ep_size=})) % {weight_block_size_n=} == 0"
+                    f"For qwen3-vl-fp8 models, please make sure ({text_config.moe_intermediate_size=} // ({self.tp_size=} // {self.moe_ep_size=})) % {weight_block_size_n=} == 0. "
+                    f"You can fix this by using arguments such as `--tp-size 8 --ep-size 8`"
                 )

     def init_torch_distributed(self):
```
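The check behind this error message is plain integer arithmetic. With placeholder values (not taken from the diff) it reads as follows:

```python
# Placeholder values, only to illustrate the divisibility check in the error above.
moe_intermediate_size = 4096  # hypothetical text_config.moe_intermediate_size
tp_size, moe_ep_size = 8, 8   # the combination suggested by the error message
weight_block_size_n = 128     # hypothetical FP8 weight block size

per_partition = moe_intermediate_size // (tp_size // moe_ep_size)  # 4096 // 1 = 4096
assert per_partition % weight_block_size_n == 0  # 4096 % 128 == 0, so the check passes
```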
python/sglang/srt/sampling/sampling_params.py

```diff
@@ -17,8 +17,6 @@ import logging
 import sre_parse
 from typing import Any, Dict, List, Optional, Union

-from sglang.srt.utils import get_bool_env_var
-
 _SAMPLING_EPS = 1e-6
 TOP_K_ALL = 1 << 30
```
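A note on the constants visible in this hunk: `TOP_K_ALL = 1 << 30` is simply a value far larger than any vocabulary, so using it as the top-k presumably means "consider every token", i.e. top-k filtering is effectively disabled. A trivial illustration:

```python
# Only arithmetic on the constants above; the vocabulary size is a placeholder.
TOP_K_ALL = 1 << 30           # 1073741824
vocab_size = 151_936          # placeholder vocabulary size
effective_top_k = min(TOP_K_ALL, vocab_size)
print(effective_top_k)        # 151936 -> the cap never bites
```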