Unverified Commit 608668e1 authored by Lianmin Zheng, committed by GitHub

Slightly improve the sampler to skip unnecessary steps (#6956)

parent 6c0a4828
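
Reviewer note on what "skip unnecessary steps" means in practice: the sampler now tracks batch-level need_top_p_sampling / need_top_k_sampling / need_min_p_sampling flags and only materializes logprobs when return_logprob is set. A minimal standalone sketch of the idea in plain PyTorch (the function name, shapes, and the top-p-only filtering are illustrative, not the SGLang API):

import torch

def sample_sketch(probs: torch.Tensor, top_ps: torch.Tensor, return_logprob: bool):
    # probs: [batch, vocab], rows already softmax-normalized; top_ps: [batch], values in (0, 1].
    if bool((top_ps == 1.0).all()):
        # Fast path: no request restricts top-p, sample straight from the distribution.
        next_ids = torch.multinomial(probs, num_samples=1).view(-1)
    else:
        # Simplified per-request top-p filtering before sampling.
        sorted_p, sorted_idx = probs.sort(dim=-1, descending=True)
        exclusive_cumsum = sorted_p.cumsum(dim=-1) - sorted_p
        sorted_p[exclusive_cumsum > top_ps.unsqueeze(-1)] = 0.0
        filtered = torch.zeros_like(probs).scatter(-1, sorted_idx, sorted_p)
        next_ids = torch.multinomial(filtered, num_samples=1).view(-1)
    logprobs = None
    if return_logprob:
        # Only materialize logprobs when asked; clamp because log(0) would be -inf.
        logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min)
    return next_ids, logprobs
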
@@ -5,7 +5,7 @@ import torch
import torch.distributed as dist
from torch import nn
from sglang.srt.distributed import get_tensor_model_parallel_group
from sglang.srt.distributed import get_tp_group
from sglang.srt.layers.dp_attention import get_attention_tp_group
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -30,7 +30,7 @@ class Sampler(nn.Module):
def __init__(self):
super().__init__()
self.use_nan_detection = global_server_args_dict["enable_nan_detection"]
self.tp_sync_group = get_tensor_model_parallel_group().device_group
self.tp_sync_group = get_tp_group().device_group
if global_server_args_dict["enable_dp_attention"]:
self.tp_sync_group = get_attention_tp_group().device_group
@@ -59,7 +59,7 @@ class Sampler(nn.Module):
# Apply the custom logit processors if registered in the sampling info.
if sampling_info.has_custom_logit_processor:
self._apply_custom_logit_processor(logits, sampling_info)
apply_custom_logit_processor(logits, sampling_info)
if self.use_nan_detection and torch.any(torch.isnan(logits)):
logger.warning("Detected errors during sampling! NaN in the logits.")
@@ -81,49 +81,39 @@ class Sampler(nn.Module):
probs = logits
del logits
if global_server_args_dict["sampling_backend"] == "flashinfer":
if return_logprob:
# NOTE: the top_p_renorm_prob from flashinfer has numerical problems,
# https://github.com/flashinfer-ai/flashinfer/issues/708
# so we use the torch implementation.
# NOTE: OpenAI's logprobs is independent of top-p, we use the
# same rule.
logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min)
max_top_k_round, batch_size = 32, probs.shape[0]
if sampling_info.need_min_p_sampling:
probs = top_k_renorm_prob(probs, sampling_info.top_ks)
probs = top_p_renorm_prob(probs, sampling_info.top_ps)
batch_next_token_ids = min_p_sampling_from_probs(
probs, sampling_info.min_ps
)
else:
# Check Nan will throw exception, only check when crash_on_warnings is True
check_nan = self.use_nan_detection and crash_on_warnings()
batch_next_token_ids = top_k_top_p_sampling_from_probs(
probs.contiguous(),
if True: # Keep this redundant check to simplify some internal code sync
if global_server_args_dict["sampling_backend"] == "flashinfer":
if sampling_info.need_min_p_sampling:
probs = top_k_renorm_prob(probs, sampling_info.top_ks)
probs = top_p_renorm_prob(probs, sampling_info.top_ps)
batch_next_token_ids = min_p_sampling_from_probs(
probs, sampling_info.min_ps
)
else:
batch_next_token_ids = top_k_top_p_sampling_from_probs(
probs,
sampling_info.top_ks,
sampling_info.top_ps,
filter_apply_order="joint",
check_nan=self.use_nan_detection,
)
elif global_server_args_dict["sampling_backend"] == "pytorch":
# A slower fallback implementation with torch native operations.
batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
probs,
sampling_info.top_ks,
sampling_info.top_ps,
filter_apply_order="joint",
check_nan=check_nan,
sampling_info.min_ps,
sampling_info.need_min_p_sampling,
)
else:
raise ValueError(
f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
)
elif global_server_args_dict["sampling_backend"] == "pytorch":
# A slower fallback implementation with torch native operations.
batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
probs,
sampling_info.top_ks,
sampling_info.top_ps,
sampling_info.min_ps,
sampling_info.need_min_p_sampling,
)
if return_logprob:
logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min)
else:
raise ValueError(
f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
)
if return_logprob:
# clamp to avoid -inf
logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min)
# Attach logprobs to logits_output (in-place modification)
if return_logprob:
@@ -160,39 +150,6 @@ class Sampler(nn.Module):
return batch_next_token_ids
def _apply_custom_logit_processor(
self, logits: torch.Tensor, sampling_batch_info: SamplingBatchInfo
):
"""Apply custom logit processors to the logits.
This function will modify the logits in-place."""
assert logits.shape[0] == len(sampling_batch_info), (
f"The batch size of logits ({logits.shape[0]}) does not match the batch size of "
f"sampling_batch_info ({len(sampling_batch_info)})"
)
for _, (
processor,
batch_mask,
) in sampling_batch_info.custom_logit_processor.items():
# Get the batch indices that need to be processed
batch_indices = batch_mask.nonzero(as_tuple=True)[0]
assert batch_mask.shape[0] == len(sampling_batch_info), (
f"The number of batch mask ({batch_mask.shape[0]}) does not match the number of "
f"sampling_batch_info ({len(sampling_batch_info)})"
)
# Apply the processor to the logits
logits[batch_mask] = processor(
logits[batch_mask],
[sampling_batch_info.custom_params[i] for i in batch_indices],
)
logger.debug(
f"Custom logit processor {processor.__class__.__name__} is applied."
)
def top_k_top_p_min_p_sampling_from_probs_torch(
probs: torch.Tensor,
@@ -221,6 +178,14 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
return batch_next_token_ids
def sampling_from_probs_torch(probs: torch.Tensor):
"""A sampling implementation with native pytorch operations, without
top-k, top-p, or min-p filtering."""
sampled_index = torch.multinomial(probs, num_samples=1)
batch_next_token_ids = sampled_index.view(-1).to(torch.int32)
return batch_next_token_ids
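
Usage sketch for the new helper above; probs just needs rows of non-negative weights, e.g. a softmax over the vocabulary (shapes are made up):

import torch

logits = torch.randn(4, 32000)                     # [batch, vocab]
probs = torch.softmax(logits, dim=-1)
next_token_ids = sampling_from_probs_torch(probs)  # shape [4], dtype torch.int32
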
def top_p_normalize_probs_torch(
probs: torch.Tensor,
top_ps: torch.Tensor,
@@ -259,3 +224,44 @@ def get_token_ids_logprobs(logprobs: torch.Tensor, token_ids_logprobs: List[List
output_token_ids_logprobs_idx.append([])
return output_token_ids_logprobs_val, output_token_ids_logprobs_idx
def apply_custom_logit_processor(
logits: torch.Tensor,
sampling_batch_info: SamplingBatchInfo,
num_tokens_in_batch: int = 1,
):
"""Apply custom logit processors to the logits.
This function will modify the logits in-place.
num_tokens_in_batch is needed to support speculative decoding, where each request in the batch
can contribute multiple tokens. By default, we assume each request contributes only 1 token.
"""
assert logits.shape[0] == len(sampling_batch_info) * num_tokens_in_batch, (
f"The batch size of logits ({logits.shape[0]}) does not match the batch size of "
f"sampling_batch_info ({len(sampling_batch_info)}) x num_tokens_in_batch "
f"({num_tokens_in_batch})"
)
for _, (
processor,
batch_mask,
) in sampling_batch_info.custom_logit_processor.items():
# Get the batch indices that need to be processed
batch_indices = batch_mask.nonzero(as_tuple=True)[0]
assert batch_mask.shape[0] == len(sampling_batch_info), (
f"The number of batch mask ({batch_mask.shape[0]}) does not match the number of "
f"sampling_batch_info ({len(sampling_batch_info)})"
)
batch_mask = torch.repeat_interleave(batch_mask, num_tokens_in_batch)
# Apply the processor to the logits
logits[batch_mask] = processor(
logits[batch_mask],
[sampling_batch_info.custom_params[i] for i in batch_indices],
)
logger.debug(
f"Custom logit processor {processor.__class__.__name__} is applied."
)
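
For intuition on the new num_tokens_in_batch argument: with speculative decoding the logits tensor holds several rows per request, so the per-request mask is stretched with torch.repeat_interleave before indexing. A standalone illustration (all values made up):

import torch

batch_mask = torch.tensor([True, False, True])     # which requests use this processor
num_tokens_in_batch = 2                            # e.g. two draft tokens per request
token_mask = torch.repeat_interleave(batch_mask, num_tokens_in_batch)
# token_mask -> tensor([True, True, False, False, True, True])
logits = torch.randn(3 * num_tokens_in_batch, 10)  # [len(batch) * tokens, vocab]
assert logits[token_mask].shape == (4, 10)         # only the masked requests' rows remain
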
@@ -852,7 +852,7 @@ class TokenizerManager:
obj.load_format = self.server_args.load_format
logger.info("Start update_weights. Load format=%s", obj.load_format)
if True:
if True: # Keep this redundant check to simplify some internal code sync
# Hold the lock if it is not async. This means that weight sync
# cannot run while requests are in progress.
async with self.model_update_lock.writer_lock:
......
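
Side note on the TokenizerManager hunk: writer_lock exists so a weight update never overlaps in-flight requests, which hold the reader side. A heavily simplified stand-in using a plain asyncio.Lock (the real code uses a reader-writer lock so ordinary requests can still run concurrently with each other; load_weights is a hypothetical coroutine):

import asyncio

model_update_lock = asyncio.Lock()  # stand-in for the reader/writer lock

async def update_weights_sketch(load_weights):
    # Hold the lock for the whole update so no request overlaps the weight swap.
    async with model_update_lock:
        await load_weights()
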
@@ -17,7 +17,7 @@
"""Inference-only LLaMA model compatible with HuggingFace weights."""
import logging
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import torch
from torch import nn
......
@@ -9,10 +9,12 @@ import torch
import sglang.srt.sampling.penaltylib as penaltylib
from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
from sglang.srt.sampling.sampling_params import TOP_K_ALL
if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import ScheduleBatch
logger = logging.getLogger(__name__)
@@ -27,6 +29,12 @@ class SamplingBatchInfo:
# Whether all requests use greedy sampling
is_all_greedy: bool
# Whether any requests use top_p sampling
need_top_p_sampling: bool
# Whether any requests use top_k sampling
need_top_k_sampling: bool
# Whether any request needs min_p sampling
need_min_p_sampling: bool
@@ -133,6 +141,8 @@ class SamplingBatchInfo:
top_ks=top_ks,
min_ps=min_ps,
is_all_greedy=all(r.sampling_params.top_k <= 1 for r in reqs),
need_top_p_sampling=any(r.sampling_params.top_p != 1.0 for r in reqs),
need_top_k_sampling=any(r.sampling_params.top_k != TOP_K_ALL for r in reqs),
need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
vocab_size=vocab_size,
penalizer_orchestrator=penalizer_orchestrator,
@@ -167,7 +177,7 @@ class SamplingBatchInfo:
# Apply the mask
for i, grammar in enumerate(self.grammars):
if grammar and not grammar.finished:
if grammar and not grammar.finished and not grammar.is_terminated():
grammar.fill_vocab_mask(self.vocab_mask, i)
# Move the mask to the device if needed
@@ -308,4 +318,6 @@ class SamplingBatchInfo:
setattr(self, item, torch.cat([self_val, other_val]))
self.is_all_greedy &= other.is_all_greedy
self.need_top_p_sampling |= other.need_top_p_sampling
self.need_top_k_sampling |= other.need_top_k_sampling
self.need_min_p_sampling |= other.need_min_p_sampling
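
The new need_* flags are plain any(...) reductions over the per-request params, and merging two batches ORs them so the skip decision stays conservative after a merge. A toy illustration with a stand-in params object (not the real Req / SamplingParams classes):

from dataclasses import dataclass

TOP_K_ALL = 1 << 30  # same sentinel introduced in sampling_params.py below

@dataclass
class FakeParams:  # stand-in for r.sampling_params
    top_p: float = 1.0
    top_k: int = TOP_K_ALL
    min_p: float = 0.0

reqs = [FakeParams(), FakeParams(top_p=0.9)]
need_top_p_sampling = any(r.top_p != 1.0 for r in reqs)        # True
need_top_k_sampling = any(r.top_k != TOP_K_ALL for r in reqs)  # False
need_min_p_sampling = any(r.min_p > 0 for r in reqs)           # False

# Merging with a batch that does need top-k keeps the flag sticky:
other_need_top_k_sampling = True
need_top_k_sampling |= other_need_top_k_sampling               # True
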
@@ -16,6 +16,7 @@
from typing import Any, Dict, List, Optional, Union
_SAMPLING_EPS = 1e-6
TOP_K_ALL = 1 << 30
class SamplingParams:
@@ -84,7 +85,7 @@ class SamplingParams:
self.temperature = 1.0
self.top_k = 1
if self.top_k == -1:
self.top_k = 1 << 30 # whole vocabulary
self.top_k = TOP_K_ALL # whole vocabulary
def verify(self):
if self.temperature < 0.0:
......
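
On the TOP_K_ALL constant: a user-facing top_k=-1 still means "whole vocabulary", but normalizing it to a named sentinel lets downstream code (e.g. the new need_top_k_sampling flag) compare against one constant instead of a repeated 1 << 30 literal. A tiny check (the helper name is made up; the real normalization is the SamplingParams hunk above):

TOP_K_ALL = 1 << 30

def normalize_top_k(top_k: int) -> int:
    # -1 means "consider the whole vocabulary", mirroring the SamplingParams hunk above
    return TOP_K_ALL if top_k == -1 else top_k

assert normalize_top_k(-1) == TOP_K_ALL
assert normalize_top_k(40) == 40
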
import dataclasses
import logging
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence
from typing import Dict, List, Optional, Sequence
import torch
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.communicator import (
CommunicateContext,
CommunicateSimpleFn,
CommunicateSummableTensorPairFn,
ScatterMode,
)
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
from sglang.srt.layers.quantization.deep_gemm import configure_deep_gemm_num_sms
from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -20,9 +18,6 @@ from sglang.srt.operations import execute_operations, execute_overlapped_operati
from sglang.srt.operations_strategy import OperationsStrategy
from sglang.srt.utils import BumpAllocator, DeepEPMode, get_bool_env_var
if TYPE_CHECKING:
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
_tbo_debug = get_bool_env_var("SGLANG_TBO_DEBUG")
logger = logging.getLogger(__name__)
@@ -46,7 +41,7 @@ def compute_split_seq_index(
assert num_tokens == 0
return 0
else:
raise NotImplementedError
raise NotImplementedError()
def _split_array_by_half_sum(arr: Sequence[int]) -> int:
......
@@ -1928,16 +1928,18 @@ def next_power_of_2(n: int):
setattr(triton, "next_power_of_2", next_power_of_2)
@contextmanager
def empty_context(*args, **kwargs):
try:
# Setup code goes here
yield
finally:
# Cleanup code goes here
class EmptyContextManager:
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
pass
def empty_context(*args, **kwargs):
return EmptyContextManager()
def add_prefix(name: str, prefix: str) -> str:
"""Add a weight path prefix to a module name.
......
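
Usage sketch for the rewritten empty_context: call sites can choose between a real context manager and the no-op one without branching around the with statement (the profiling flag and torch.profiler.profile are illustrative choices, not necessarily how SGLang uses it; one practical difference from the @contextmanager version is that no generator object is created per call):

import torch

def maybe_profile(enabled: bool):
    # Return a real context manager when profiling, otherwise the no-op one.
    return torch.profiler.profile() if enabled else empty_context()

with maybe_profile(enabled=False):
    _ = 1 + 1  # body runs unchanged; EmptyContextManager simply does nothing on exit
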