Unverified Commit 5e3f7e7f authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Minor: improve sampler & remove unused fields from model_config.py (#11531)

parent 728af887
...@@ -65,6 +65,7 @@ jobs: ...@@ -65,6 +65,7 @@ jobs:
arm64_tag: dev-arm64 arm64_tag: dev-arm64
steps: steps:
- uses: docker/setup-buildx-action@v3 - uses: docker/setup-buildx-action@v3
- uses: docker/login-action@v2 - uses: docker/login-action@v2
with: with:
username: ${{ secrets.DOCKERHUB_USERNAME }} username: ${{ secrets.DOCKERHUB_USERNAME }}
...@@ -72,9 +73,10 @@ jobs: ...@@ -72,9 +73,10 @@ jobs:
- run: | - run: |
docker buildx imagetools create \ docker buildx imagetools create \
-t lmsysorg/sglang:${{ matrix.variant.tag }} \ -t lmsysorg/sglang:${{ matrix.variant.tag }} \
-t lmsysorg/sglang:nightly-${{ matrix.variant.tag }}-${{ github.sha }} \ -t lmsysorg/sglang:nightly-${{ matrix.variant.tag }}-$(date +%Y%m%d)-${{ github.sha:0:8 }} \
lmsysorg/sglang:${{ matrix.variant.x86_tag }} \ lmsysorg/sglang:${{ matrix.variant.x86_tag }} \
lmsysorg/sglang:${{ matrix.variant.arm64_tag }} lmsysorg/sglang:${{ matrix.variant.arm64_tag }}
- name: Cleanup Old Nightly Builds - name: Cleanup Old Nightly Builds
run: | run: |
# Get JWT token for Docker Hub API # Get JWT token for Docker Hub API
......
...@@ -25,7 +25,7 @@ from transformers import PretrainedConfig ...@@ -25,7 +25,7 @@ from transformers import PretrainedConfig
from sglang.srt.environ import envs from sglang.srt.environ import envs
from sglang.srt.layers.quantization import QUANTIZATION_METHODS from sglang.srt.layers.quantization import QUANTIZATION_METHODS
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import is_hip, retry from sglang.srt.utils import is_hip
from sglang.srt.utils.hf_transformers_utils import ( from sglang.srt.utils.hf_transformers_utils import (
get_config, get_config,
get_context_length, get_context_length,
...@@ -86,11 +86,11 @@ class ModelConfig: ...@@ -86,11 +86,11 @@ class ModelConfig:
dtype: str = "auto", dtype: str = "auto",
quantization: Optional[str] = None, quantization: Optional[str] = None,
modelopt_quant: Optional[Union[str, Dict]] = None, modelopt_quant: Optional[Union[str, Dict]] = None,
modelopt_checkpoint_restore_path: Optional[str] = None,
modelopt_checkpoint_save_path: Optional[str] = None,
override_config_file: Optional[str] = None, override_config_file: Optional[str] = None,
is_draft_model: bool = False, is_draft_model: bool = False,
hybrid_kvcache_ratio: Optional[float] = None, hybrid_kvcache_ratio: Optional[
float
] = None, # TODO: remove this, it is not a model config
model_impl: Union[str, ModelImpl] = ModelImpl.AUTO, model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
sampling_defaults: str = "openai", sampling_defaults: str = "openai",
) -> None: ) -> None:
......
...@@ -92,6 +92,12 @@ class Sampler(nn.Module): ...@@ -92,6 +92,12 @@ class Sampler(nn.Module):
if return_logprob: if return_logprob:
logprobs = torch.nn.functional.log_softmax(logits, dim=-1) logprobs = torch.nn.functional.log_softmax(logits, dim=-1)
else: else:
can_sample_directly_from_probs = (
not sampling_info.need_top_p_sampling
and not sampling_info.need_top_k_sampling
and not sampling_info.need_min_p_sampling
)
# If requested, cache probabilities from original logits before temperature scaling. # If requested, cache probabilities from original logits before temperature scaling.
if return_logprob and RETURN_ORIGINAL_LOGPROB: if return_logprob and RETURN_ORIGINAL_LOGPROB:
probs_without_temp_scaling = torch.softmax(logits, dim=-1) probs_without_temp_scaling = torch.softmax(logits, dim=-1)
...@@ -102,7 +108,14 @@ class Sampler(nn.Module): ...@@ -102,7 +108,14 @@ class Sampler(nn.Module):
probs = logits probs = logits
del logits del logits
if True: # Keep this redundant check to simplify some internal code sync if can_sample_directly_from_probs:
# when we don't need top-k, top-p, or min-p sampling, we can directly sample from the probs
batch_next_token_ids = sampling_from_probs_torch(
probs,
sampling_seed=sampling_info.sampling_seed,
positions=positions,
)
else:
if get_global_server_args().sampling_backend == "flashinfer": if get_global_server_args().sampling_backend == "flashinfer":
if sampling_info.need_min_p_sampling: if sampling_info.need_min_p_sampling:
probs = top_k_renorm_prob(probs, sampling_info.top_ks) probs = top_k_renorm_prob(probs, sampling_info.top_ks)
......
...@@ -648,7 +648,8 @@ class ModelRunner: ...@@ -648,7 +648,8 @@ class ModelRunner:
// (self.tp_size // self.moe_ep_size) // (self.tp_size // self.moe_ep_size)
) % weight_block_size_n != 0: ) % weight_block_size_n != 0:
raise ValueError( raise ValueError(
f"For qwen3-vl-fp8 models, please make sure ({text_config.moe_intermediate_size=} // ({self.tp_size=} // {self.moe_ep_size=})) % {weight_block_size_n=} == 0" f"For qwen3-vl-fp8 models, please make sure ({text_config.moe_intermediate_size=} // ({self.tp_size=} // {self.moe_ep_size=})) % {weight_block_size_n=} == 0. "
f"You can fix this by using arguments such as `--tp-size 8 --ep-size 8`"
) )
def init_torch_distributed(self): def init_torch_distributed(self):
......
...@@ -17,8 +17,6 @@ import logging ...@@ -17,8 +17,6 @@ import logging
import sre_parse import sre_parse
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
from sglang.srt.utils import get_bool_env_var
_SAMPLING_EPS = 1e-6 _SAMPLING_EPS = 1e-6
TOP_K_ALL = 1 << 30 TOP_K_ALL = 1 << 30
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment