Unverified Commit 608668e1 authored by Lianmin Zheng, committed by GitHub

Slightly improve the sampler to skip unnecessary steps (#6956)

parent 6c0a4828
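Summary of the change: SamplingBatchInfo now records whether any request in the batch actually uses top-p, top-k, or min-p sampling (new need_top_p_sampling / need_top_k_sampling flags below), presumably so downstream code can skip renormalization and filtering passes when no request asks for them; the custom-logit-processor hook also becomes a module-level function with a num_tokens_in_batch argument for speculative decoding. A rough, hedged sketch of the skip idea with illustrative names (not the exact SGLang code):

import torch

# Illustrative sketch only: skip the top-k/top-p passes entirely when no
# request in the batch asks for them.
def sample(probs, top_ks, top_ps, need_top_k_sampling, need_top_p_sampling):
    if need_top_k_sampling:
        # Zero out everything below each row's k-th largest probability
        # (simplified: uses the max k across the batch), then renormalize.
        kth = torch.topk(probs, int(top_ks.max()), dim=-1).values[:, -1:]
        probs = torch.where(probs >= kth, probs, torch.zeros_like(probs))
        probs = probs / probs.sum(dim=-1, keepdim=True)
    if need_top_p_sampling:
        # Nucleus filtering (simplified): drop the tail whose cumulative mass
        # already exceeds top_p, then renormalize.
        sorted_probs, idx = probs.sort(dim=-1, descending=True)
        cum = sorted_probs.cumsum(dim=-1)
        sorted_probs[cum - sorted_probs > top_ps.unsqueeze(-1)] = 0.0
        sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
        probs = torch.zeros_like(probs).scatter(-1, idx, sorted_probs)
    return torch.multinomial(probs, num_samples=1).view(-1)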
@@ -5,7 +5,7 @@ import torch
 import torch.distributed as dist
 from torch import nn
-from sglang.srt.distributed import get_tensor_model_parallel_group
+from sglang.srt.distributed import get_tp_group
 from sglang.srt.layers.dp_attention import get_attention_tp_group
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -30,7 +30,7 @@ class Sampler(nn.Module):
     def __init__(self):
         super().__init__()
         self.use_nan_detection = global_server_args_dict["enable_nan_detection"]
-        self.tp_sync_group = get_tensor_model_parallel_group().device_group
+        self.tp_sync_group = get_tp_group().device_group
         if global_server_args_dict["enable_dp_attention"]:
             self.tp_sync_group = get_attention_tp_group().device_group
@@ -59,7 +59,7 @@ class Sampler(nn.Module):
         # Apply the custom logit processors if registered in the sampling info.
         if sampling_info.has_custom_logit_processor:
-            self._apply_custom_logit_processor(logits, sampling_info)
+            apply_custom_logit_processor(logits, sampling_info)
         if self.use_nan_detection and torch.any(torch.isnan(logits)):
             logger.warning("Detected errors during sampling! NaN in the logits.")
@@ -81,16 +81,8 @@ class Sampler(nn.Module):
             probs = logits
             del logits
-            if True:  # Keep this redundant check to simplify some internal code sync
             if global_server_args_dict["sampling_backend"] == "flashinfer":
-                if return_logprob:
-                    # NOTE: the top_p_renorm_prob from flashinfer has numerical problems,
-                    # https://github.com/flashinfer-ai/flashinfer/issues/708
-                    # so we use the torch implementation.
-                    # NOTE: OpenAI's logprobs is independent of top-p, we use the
-                    # same rule.
-                    logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min)
-                max_top_k_round, batch_size = 32, probs.shape[0]
                 if sampling_info.need_min_p_sampling:
                     probs = top_k_renorm_prob(probs, sampling_info.top_ks)
                     probs = top_p_renorm_prob(probs, sampling_info.top_ps)
@@ -98,16 +90,13 @@ class Sampler(nn.Module):
                         probs, sampling_info.min_ps
                     )
                 else:
-                    # Check Nan will throw exception, only check when crash_on_warnings is True
-                    check_nan = self.use_nan_detection and crash_on_warnings()
                     batch_next_token_ids = top_k_top_p_sampling_from_probs(
-                        probs.contiguous(),
+                        probs,
                         sampling_info.top_ks,
                         sampling_info.top_ps,
                         filter_apply_order="joint",
-                        check_nan=check_nan,
+                        check_nan=self.use_nan_detection,
                     )
             elif global_server_args_dict["sampling_backend"] == "pytorch":
                 # A slower fallback implementation with torch native operations.
                 batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
@@ -117,14 +106,15 @@ class Sampler(nn.Module):
                     sampling_info.min_ps,
                     sampling_info.need_min_p_sampling,
                 )
-                if return_logprob:
-                    logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min)
             else:
                 raise ValueError(
                     f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
                 )
+            if return_logprob:
+                # clamp to avoid -inf
+                logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min)
         # Attach logprobs to logits_output (in-place modification)
         if return_logprob:
             if any(x > 0 for x in top_logprobs_nums):
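On the moved logprob computation: it now runs once after the backend branch instead of inside each branch, and the clamp guards against -inf values produced by log(0) once filtering has zeroed out probabilities. A tiny standalone illustration:

import torch

probs = torch.tensor([[0.7, 0.3, 0.0]])   # a zeroed-out entry after top-k/top-p filtering
logprobs = torch.log(probs)               # the last entry becomes -inf
logprobs = logprobs.clamp(min=torch.finfo(probs.dtype).min)
# -inf is replaced by the smallest finite float, keeping later reductions well-defined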
@@ -160,39 +150,6 @@ class Sampler(nn.Module):
         return batch_next_token_ids

-    def _apply_custom_logit_processor(
-        self, logits: torch.Tensor, sampling_batch_info: SamplingBatchInfo
-    ):
-        """Apply custom logit processors to the logits.
-        This function will modify the logits in-place."""
-        assert logits.shape[0] == len(sampling_batch_info), (
-            f"The batch size of logits ({logits.shape[0]}) does not match the batch size of "
-            f"sampling_batch_info ({len(sampling_batch_info)})"
-        )
-        for _, (
-            processor,
-            batch_mask,
-        ) in sampling_batch_info.custom_logit_processor.items():
-            # Get the batch indices that need to be processed
-            batch_indices = batch_mask.nonzero(as_tuple=True)[0]
-            assert batch_mask.shape[0] == len(sampling_batch_info), (
-                f"The number of batch mask ({batch_mask.shape[0]}) does not match the number of "
-                f"sampling_batch_info ({len(sampling_batch_info)})"
-            )
-            # Apply the processor to the logits
-            logits[batch_mask] = processor(
-                logits[batch_mask],
-                [sampling_batch_info.custom_params[i] for i in batch_indices],
-            )
-            logger.debug(
-                f"Custom logit processor {processor.__class__.__name__} is applied."
-            )

 def top_k_top_p_min_p_sampling_from_probs_torch(
     probs: torch.Tensor,
@@ -221,6 +178,14 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
     return batch_next_token_ids

+def sampling_from_probs_torch(probs: torch.Tensor):
+    """A sampling implementation with native pytorch operations, without
+    top-k, top-p, or min-p filtering."""
+    sampled_index = torch.multinomial(probs, num_samples=1)
+    batch_next_token_ids = sampled_index.view(-1).to(torch.int32)
+    return batch_next_token_ids

 def top_p_normalize_probs_torch(
     probs: torch.Tensor,
     top_ps: torch.Tensor,
@@ -259,3 +224,44 @@ def get_token_ids_logprobs(logprobs: torch.Tensor, token_ids_logprobs: List[List
             output_token_ids_logprobs_idx.append([])

     return output_token_ids_logprobs_val, output_token_ids_logprobs_idx

+def apply_custom_logit_processor(
+    logits: torch.Tensor,
+    sampling_batch_info: SamplingBatchInfo,
+    num_tokens_in_batch: int = 1,
+):
+    """Apply custom logit processors to the logits.
+    This function will modify the logits in-place.
+    num_tokens_in_batch is needed to support spec decoding, where each batch can contain multiple
+    tokens. By default, we assume each batch contains only 1 token.
+    """
+    assert logits.shape[0] == len(sampling_batch_info) * num_tokens_in_batch, (
+        f"The batch size of logits ({logits.shape[0]}) does not match the batch size of "
+        f"sampling_batch_info ({len(sampling_batch_info)}) x num_tokens_in_batch "
+        f"({num_tokens_in_batch})"
+    )
+    for _, (
+        processor,
+        batch_mask,
+    ) in sampling_batch_info.custom_logit_processor.items():
+        # Get the batch indices that need to be processed
+        batch_indices = batch_mask.nonzero(as_tuple=True)[0]
+        assert batch_mask.shape[0] == len(sampling_batch_info), (
+            f"The number of batch mask ({batch_mask.shape[0]}) does not match the number of "
+            f"sampling_batch_info ({len(sampling_batch_info)})"
+        )
+        batch_mask = torch.repeat_interleave(batch_mask, num_tokens_in_batch)
+        # Apply the processor to the logits
+        logits[batch_mask] = processor(
+            logits[batch_mask],
+            [sampling_batch_info.custom_params[i] for i in batch_indices],
+        )
+        logger.debug(
+            f"Custom logit processor {processor.__class__.__name__} is applied."
+        )
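For intuition on the new num_tokens_in_batch argument: with speculative decoding each request contributes several tokens per step, so the per-request mask has to be expanded to a per-token mask before indexing the logits. A minimal standalone sketch of that expansion (plain PyTorch, not SGLang code):

import torch

batch_mask = torch.tensor([True, False, True])   # which requests use the processor
num_tokens_in_batch = 2                          # e.g. 2 speculative tokens per request
token_mask = torch.repeat_interleave(batch_mask, num_tokens_in_batch)
# token_mask -> tensor([True, True, False, False, True, True])
# logits has shape [3 * 2, vocab_size]; logits[token_mask] selects the rows
# belonging to requests 0 and 2, so the custom processor only sees those rows.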
@@ -852,7 +852,7 @@ class TokenizerManager:
             obj.load_format = self.server_args.load_format
         logger.info("Start update_weights. Load format=%s", obj.load_format)

-        if True:
+        if True:  # Keep this redundant check to simplify some internal code sync
            # Hold the lock if it is not async. This means that weight sync
            # cannot run while requests are in progress.
            async with self.model_update_lock.writer_lock:
......
@@ -17,7 +17,7 @@
 """Inference-only LLaMA model compatible with HuggingFace weights."""
 import logging
-from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 import torch
 from torch import nn
......
@@ -9,10 +9,12 @@ import torch
 import sglang.srt.sampling.penaltylib as penaltylib
 from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
+from sglang.srt.sampling.sampling_params import TOP_K_ALL

 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import ScheduleBatch

 logger = logging.getLogger(__name__)
@@ -27,6 +29,12 @@ class SamplingBatchInfo:
     # Whether all requests use greedy sampling
     is_all_greedy: bool

+    # Whether any requests use top_p sampling
+    need_top_p_sampling: bool
+
+    # Whether any requests use top_k sampling
+    need_top_k_sampling: bool
+
     # Whether any request needs min_p sampling
     need_min_p_sampling: bool
@@ -133,6 +141,8 @@ class SamplingBatchInfo:
            top_ks=top_ks,
            min_ps=min_ps,
            is_all_greedy=all(r.sampling_params.top_k <= 1 for r in reqs),
+           need_top_p_sampling=any(r.sampling_params.top_p != 1.0 for r in reqs),
+           need_top_k_sampling=any(r.sampling_params.top_k != TOP_K_ALL for r in reqs),
            need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
            vocab_size=vocab_size,
            penalizer_orchestrator=penalizer_orchestrator,
@@ -167,7 +177,7 @@ class SamplingBatchInfo:
         # Apply the mask
         for i, grammar in enumerate(self.grammars):
-            if grammar and not grammar.finished:
+            if grammar and not grammar.finished and not grammar.is_terminated():
                 grammar.fill_vocab_mask(self.vocab_mask, i)

         # Move the mask to the device if needed
@@ -308,4 +318,6 @@ class SamplingBatchInfo:
             setattr(self, item, torch.cat([self_val, other_val]))

         self.is_all_greedy &= other.is_all_greedy
+        self.need_top_p_sampling |= other.need_top_p_sampling
+        self.need_top_k_sampling |= other.need_top_k_sampling
         self.need_min_p_sampling |= other.need_min_p_sampling
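The new flags are derived once when the batch is built (any request deviating from the defaults sets the flag) and OR-ed together when batches are merged, so a merged batch still knows whether a filtering step can be skipped. A self-contained sketch of that bookkeeping with illustrative names, not the SamplingBatchInfo API:

from dataclasses import dataclass

TOP_K_ALL = 1 << 30  # sentinel meaning "consider the whole vocabulary"

@dataclass
class Flags:
    need_top_p_sampling: bool
    need_top_k_sampling: bool
    need_min_p_sampling: bool

    @staticmethod
    def from_requests(reqs):
        # reqs: iterable of objects with top_p, top_k, min_p attributes
        return Flags(
            need_top_p_sampling=any(r.top_p != 1.0 for r in reqs),
            need_top_k_sampling=any(r.top_k != TOP_K_ALL for r in reqs),
            need_min_p_sampling=any(r.min_p > 0 for r in reqs),
        )

    def merge(self, other):
        # Merging two batches: a step is needed if either batch needs it.
        self.need_top_p_sampling |= other.need_top_p_sampling
        self.need_top_k_sampling |= other.need_top_k_sampling
        self.need_min_p_sampling |= other.need_min_p_sampling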
@@ -16,6 +16,7 @@
 from typing import Any, Dict, List, Optional, Union

 _SAMPLING_EPS = 1e-6
+TOP_K_ALL = 1 << 30

 class SamplingParams:
@@ -84,7 +85,7 @@ class SamplingParams:
             self.temperature = 1.0
             self.top_k = 1
         if self.top_k == -1:
-            self.top_k = 1 << 30  # whole vocabulary
+            self.top_k = TOP_K_ALL  # whole vocabulary

     def verify(self):
         if self.temperature < 0.0:
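On the TOP_K_ALL sentinel introduced above: top_k = -1 ("no top-k restriction") is normalized to 1 << 30, so batch-level code can detect whether any request actually restricts top-k by comparing against the sentinel. A small illustration (plain Python, not the SamplingParams class):

TOP_K_ALL = 1 << 30

def normalize_top_k(top_k: int) -> int:
    # -1 means "disabled"; map it to the whole-vocabulary sentinel.
    return TOP_K_ALL if top_k == -1 else top_k

requested = [normalize_top_k(k) for k in (-1, -1, 50)]
need_top_k_sampling = any(k != TOP_K_ALL for k in requested)  # True: one request uses top_k=50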
......
 import dataclasses
 import logging
-from typing import TYPE_CHECKING, Dict, List, Optional, Sequence
+from typing import Dict, List, Optional, Sequence
 import torch
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.communicator import (
     CommunicateContext,
-    CommunicateSimpleFn,
     CommunicateSummableTensorPairFn,
     ScatterMode,
 )
-from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
 from sglang.srt.layers.quantization.deep_gemm import configure_deep_gemm_num_sms
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -20,9 +18,6 @@ from sglang.srt.operations import execute_operations, execute_overlapped_operati
 from sglang.srt.operations_strategy import OperationsStrategy
 from sglang.srt.utils import BumpAllocator, DeepEPMode, get_bool_env_var

-if TYPE_CHECKING:
-    from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner

 _tbo_debug = get_bool_env_var("SGLANG_TBO_DEBUG")
 logger = logging.getLogger(__name__)
@@ -46,7 +41,7 @@ def compute_split_seq_index(
         assert num_tokens == 0
         return 0
     else:
-        raise NotImplementedError
+        raise NotImplementedError()

 def _split_array_by_half_sum(arr: Sequence[int]) -> int:
......
@@ -1928,16 +1928,18 @@ def next_power_of_2(n: int):
 setattr(triton, "next_power_of_2", next_power_of_2)

-@contextmanager
-def empty_context(*args, **kwargs):
-    try:
-        # Setup code goes here
-        yield
-    finally:
-        # Cleanup code goes here
-        pass
+class EmptyContextManager:
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        pass
+
+
+def empty_context(*args, **kwargs):
+    return EmptyContextManager()

 def add_prefix(name: str, prefix: str) -> str:
     """Add a weight path prefix to a module name.
......
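On the utils change above: empty_context() stays a no-op stand-in that call sites can use interchangeably with a real context manager, but the class-based version avoids the generator machinery of @contextmanager. A hedged usage sketch; maybe_timer is a made-up helper, not SGLang code:

import time
from contextlib import contextmanager

class EmptyContextManager:          # as in the diff above
    def __enter__(self):
        return self
    def __exit__(self, exc_type, exc_value, traceback):
        pass

def empty_context(*args, **kwargs): # accepts and ignores any arguments
    return EmptyContextManager()

@contextmanager
def maybe_timer(label):             # hypothetical "real" context manager
    start = time.perf_counter()
    try:
        yield
    finally:
        print(f"{label}: {time.perf_counter() - start:.3f}s")

def run_step(enable_timing: bool):
    # The call site stays uniform whether or not timing is enabled.
    with (maybe_timer("step") if enable_timing else empty_context("step")):
        sum(range(10_000))          # placeholder for the real work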