Unverified Commit 2ac46e94 authored by Lianmin Zheng, committed by GitHub

Sync changes on io_struct.py and deterministic ops (#11498)

parent 0aa65f94
...@@ -321,7 +321,6 @@ Please consult the documentation below and [server_args.py](https://github.com/s
 | `--debug-tensor-dump-output-folder` | The output folder for debug tensor dumps. | None |
 | `--debug-tensor-dump-input-file` | The input file for debug tensor dumps. | None |
 | `--debug-tensor-dump-inject` | Enable injection of debug tensor dumps. | False |
-| `--debug-tensor-dump-prefill-only` | Enable prefill-only mode for debug tensor dumps. | False |

 ## PD disaggregation
......
...@@ -240,6 +240,7 @@ class GroupCoordinator:
         use_message_queue_broadcaster: bool = False,
         group_name: Optional[str] = None,
         torch_compile: Optional[bool] = None,
+        gloo_timeout: timedelta = timedelta(seconds=120 * 60),
     ):
         # Set group info
         group_name = group_name or "anonymous"
...@@ -259,7 +260,9 @@ class GroupCoordinator:
         )
         # a group with `gloo` backend, to allow direct coordination between
         # processes through the CPU.
-        cpu_group = torch.distributed.new_group(ranks, backend="gloo")
+        cpu_group = torch.distributed.new_group(
+            ranks, backend="gloo", timeout=gloo_timeout
+        )
         if self.rank in ranks:
             self.ranks = ranks
             self.world_size = len(ranks)
......
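Reviewer note: the new `gloo_timeout` knob defaults to two hours and is forwarded straight to `torch.distributed.new_group`. A minimal sketch of the pattern (illustrative helper, not SGLang's `GroupCoordinator`; assumes the default process group is already initialized):

```python
from datetime import timedelta

import torch.distributed as dist

def make_cpu_group(ranks, gloo_timeout: timedelta = timedelta(seconds=120 * 60)):
    # Assumes init_process_group() has already run. A generous timeout keeps
    # long CPU-side coordination (e.g., weight reloads) from aborting the group.
    return dist.new_group(ranks, backend="gloo", timeout=gloo_timeout)
```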
...@@ -91,7 +91,6 @@ class Sampler(nn.Module):
             batch_next_token_ids = torch.argmax(logits, -1)
             if return_logprob:
                 logprobs = torch.nn.functional.log_softmax(logits, dim=-1)
         else:
-            # If requested, cache probabilities from original logits before temperature scaling.
             if return_logprob and RETURN_ORIGINAL_LOGPROB:
...@@ -288,21 +287,29 @@ def multinomial_with_seed(
     """
     n, m = inputs.shape
     col_indices = torch.arange(m, device=inputs.device).unsqueeze(0)
-    step_seed = seed * 19349663 ^ positions * 73856093
+    step_seed = (seed * 19349663) ^ (positions * 73856093)
     seed_expanded = step_seed.unsqueeze(-1)
-    hashed = seed_expanded * 8589934591 ^ col_indices * 479001599
+    hashed = (seed_expanded * 8589934591) ^ (col_indices * 479001599)
     uniform_samples = (hashed % (2**24)).float() / (2**24)
-    epsilon = 1e-9
-    gumbel_noise = -torch.log(-torch.log(uniform_samples + epsilon) + epsilon)
+    epsilon = 1e-10
+    uniform_samples = uniform_samples.clamp(epsilon, 1.0 - epsilon)
+    gumbel_noise = -torch.log(-torch.log(uniform_samples))
     log_probs = torch.log(inputs + epsilon)
     perturbed_log_probs = log_probs + gumbel_noise
     return torch.argmax(perturbed_log_probs, dim=1, keepdim=True)


-def sampling_from_probs_torch(probs: torch.Tensor):
+def sampling_from_probs_torch(
+    probs: torch.Tensor,
+    sampling_seed: Optional[torch.Tensor] = None,
+    positions: Optional[torch.Tensor] = None,
+):
     """A sampling implementation with native pytorch operations, without
     top-k, top-p, or min-p filtering."""
-    sampled_index = torch.multinomial(probs, num_samples=1)
+    if sampling_seed is not None:
+        sampled_index = multinomial_with_seed(probs, sampling_seed, positions)
+    else:
+        sampled_index = torch.multinomial(probs, num_samples=1)
     batch_next_token_ids = sampled_index.view(-1).to(torch.int32)
     return batch_next_token_ids
......
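For reference, a self-contained sketch of the deterministic sampler these hunks converge on (the function name is illustrative; the real implementation is `multinomial_with_seed` above):

```python
import torch

def seeded_gumbel_sample(
    probs: torch.Tensor,      # [n, m] row-stochastic probabilities
    seed: torch.Tensor,       # [n] int64 per-request seeds
    positions: torch.Tensor,  # [n] int64 decode positions
) -> torch.Tensor:
    n, m = probs.shape
    cols = torch.arange(m, device=probs.device).unsqueeze(0)
    # Two multiply-xor hash rounds map (seed, position, column) to a
    # pseudo-random 24-bit value per matrix entry.
    step_seed = (seed * 19349663) ^ (positions * 73856093)
    hashed = (step_seed.unsqueeze(-1) * 8589934591) ^ (cols * 479001599)
    u = (hashed % (2**24)).float() / (2**24)
    # Clamp away from 0 and 1 so the iterated log stays finite.
    u = u.clamp(1e-10, 1.0 - 1e-10)
    gumbel = -torch.log(-torch.log(u))
    # Gumbel-max trick: argmax of log p + Gumbel noise ~ Categorical(p).
    return torch.argmax(torch.log(probs + 1e-10) + gumbel, dim=1, keepdim=True)
```

Because the noise is a pure hash of (seed, position, column), the same request replayed at the same position draws the same token regardless of batch composition; the clamp is what keeps `-log(-log(u))` finite when the hash lands on u near 0 or 1.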
...@@ -245,9 +245,11 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
                 input_token_ids_logprobs_idx=recv_obj.input_token_ids_logprobs_idx,
                 output_token_ids_logprobs_val=recv_obj.output_token_ids_logprobs_val,
                 output_token_ids_logprobs_idx=recv_obj.output_token_ids_logprobs_idx,
+                output_token_entropy_val=recv_obj.output_token_entropy_val,
                 output_hidden_states=recv_obj.output_hidden_states,
                 placeholder_tokens_idx=None,
                 placeholder_tokens_val=None,
+                token_steps=recv_obj.token_steps,
             )

     def handle_multimodal_decode_req(self, recv_obj: BatchMultimodalDecodeReq):
......
...@@ -170,6 +170,9 @@ class GenerateReqInput(BaseReq):
     # (Internal) Whether to return bytes for image generation
     return_bytes: bool = False

+    # Whether to return entropy
+    return_entropy: bool = False
+
     def contains_mm_input(self) -> bool:
         return (
             has_valid_data(self.image_data)
...@@ -568,6 +571,7 @@ class GenerateReqInput(BaseReq):
             no_logs=self.no_logs,
             custom_labels=self.custom_labels,
             return_bytes=self.return_bytes,
+            return_entropy=self.return_entropy,
         )
...@@ -633,6 +637,9 @@ class TokenizedGenerateReqInput(BaseReq):
     # (Internal) Whether to return bytes for image generation
     return_bytes: bool = False

+    # Whether to return entropy
+    return_entropy: bool = False
+
 @dataclass
 class BatchTokenizedGenerateReqInput(BaseBatchReq):
...@@ -830,6 +837,7 @@ class BatchTokenIDOutput(BaseBatchReq):
     input_token_ids_logprobs_idx: List[List]
     output_token_ids_logprobs_val: List[List]
     output_token_ids_logprobs_idx: List[List]
+    output_token_entropy_val: List[float]

     # Hidden states
     output_hidden_states: List[List[float]]
...@@ -840,6 +848,9 @@ class BatchTokenIDOutput(BaseBatchReq):
     placeholder_tokens_idx: List[Optional[List[int]]]
     placeholder_tokens_val: List[Optional[List[int]]]

+    # The trainer step id. Used to know which step's weights are used for sampling.
+    token_steps: List[List[int]] = None
+
 @dataclass
 class BatchMultimodalDecodeReq(BaseBatchReq):
...@@ -861,11 +872,14 @@ class BatchMultimodalDecodeReq(BaseBatchReq):
     completion_tokens: List[int]
     cached_tokens: List[int]

-    # Placeholder token info
+    # The information of placeholder tokens (e.g., image token)
+    # idx is the index of the token in the prompt after expansion.
+    # val is the length of padded tokens after expansion.
     placeholder_tokens_idx: List[Optional[List[int]]]
     placeholder_tokens_val: List[Optional[List[int]]]

-    return_bytes: bool = False
+    # The trainer step id. Used to know which step's weights are used for sampling.
+    token_steps: List[List[int]] = None

 @dataclass
...@@ -896,13 +910,20 @@ class BatchStrOutput(BaseBatchReq):
     input_token_ids_logprobs_idx: List[List]
     output_token_ids_logprobs_val: List[List]
     output_token_ids_logprobs_idx: List[List]
+    output_token_entropy_val: List[float]

     # Hidden states
     output_hidden_states: List[List[float]]

+    # The information of placeholder tokens (e.g., image token)
+    # idx is the index of the token in the prompt after expansion.
+    # val is the length of padded tokens after expansion.
     placeholder_tokens_idx: List[Optional[List[int]]]
     placeholder_tokens_val: List[Optional[List[int]]]

+    # The trainer step id. Used to know which step's weights are used for sampling.
+    token_steps: List[List[int]] = None
+
 @dataclass
 class BatchMultimodalOutput(BaseBatchReq):
...@@ -979,6 +1000,8 @@ class UpdateWeightFromDiskReqInput(BaseReq):
     torch_empty_cache: bool = False
     # Whether to keep the scheduler paused after weight update
     keep_pause: bool = False
+    # The trainer step id. Used to know which step's weights are used for sampling.
+    token_step: int = 0

 @dataclass
...@@ -1416,6 +1439,16 @@ class WatchLoadUpdateReq(BaseReq):
     loads: List[GetLoadReqOutput]

+
+@dataclass
+class LazyDumpTensorsReqInput(BaseReq):
+    pass
+
+
+@dataclass
+class LazyDumpTensorsReqOutput(BaseReq):
+    success: bool
+
 def _check_all_req_types():
     """A helper function to check all request types are defined in this file."""
     import inspect
......
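These hunks only declare transport fields. For orientation, the value a field like `output_token_entropy_val` carries is conventionally the Shannon entropy of the next-token distribution; a sketch of the standard formula (not necessarily SGLang's exact kernel):

```python
import torch

def next_token_entropy(logits: torch.Tensor) -> torch.Tensor:
    # Shannon entropy per row: H = -sum_v p(v) * log p(v).
    logp = torch.log_softmax(logits.float(), dim=-1)
    return -(logp.exp() * logp).sum(dim=-1)
```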
...@@ -190,6 +190,11 @@ def _handle_output_by_index(output, i):
                 if output.output_token_ids_logprobs_idx
                 else None
             ),
+            output_token_entropy_val=(
+                [output.output_token_entropy_val[i]]
+                if output.output_token_entropy_val
+                else None
+            ),
             output_hidden_states=(
                 [output.output_hidden_states[i]]
                 if output.output_hidden_states
...@@ -197,6 +202,7 @@ def _handle_output_by_index(output, i):
             ),
             placeholder_tokens_idx=None,
             placeholder_tokens_val=None,
+            token_steps=([output.token_steps[i]] if output.token_steps else None),
         )
     elif isinstance(output, BatchEmbeddingOutput):
         new_output = BatchEmbeddingOutput(
...@@ -306,6 +312,11 @@ def _handle_output_by_index(output, i):
                 if output.output_token_ids_logprobs_idx
                 else None
             ),
+            output_token_entropy_val=(
+                [output.output_token_entropy_val[i]]
+                if output.output_token_entropy_val
+                else None
+            ),
             output_hidden_states=(
                 [output.output_hidden_states[i]]
                 if output.output_hidden_states
...@@ -313,6 +324,7 @@ def _handle_output_by_index(output, i):
             ),
             placeholder_tokens_idx=None,
             placeholder_tokens_val=None,
+            token_steps=([output.token_steps[i]] if output.token_steps else None),
         )
     elif isinstance(output, BatchMultimodalOutput):
         new_output = BatchMultimodalOutput(
......
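Every new per-request field goes through the same select-one-row guard. Abstracted (hypothetical helper, not in the patch):

```python
from typing import List, Optional

def take_row(field: Optional[List], i: int) -> Optional[List]:
    # Preserve None/empty for batches that never carried the field, so older
    # outputs without entropy or token_steps still split per request cleanly.
    return [field[i]] if field else None
```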
...@@ -920,7 +920,8 @@ class SchedulerOutputProcessorMixin:
                 input_token_ids_logprobs_idx,
                 output_token_ids_logprobs_val,
                 output_token_ids_logprobs_idx,
-                output_hidden_states,
+                output_token_entropy_val=None,
+                output_hidden_states=output_hidden_states,
                 rids=rids,
                 placeholder_tokens_idx=None,
                 placeholder_tokens_val=None,
......
...@@ -73,9 +73,6 @@ logger = logging.getLogger(__name__)

 # Dump tensors for debugging
 debug_tensor_dump_output_folder = None
-debug_tensor_dump_prefill_only = False
-# Skip all the other tensor dumps, only dump the target logits
-debug_tensor_dump_only_target_logprobs = False
 debug_tensor_dump_inject = False
 debug_tensor_dump_layers = None
 debug_tensor_dump_test = False
......
...@@ -455,7 +455,6 @@ class ServerArgs:
     debug_tensor_dump_output_folder: Optional[str] = None
     debug_tensor_dump_input_file: Optional[str] = None
     debug_tensor_dump_inject: bool = False
-    debug_tensor_dump_prefill_only: bool = False

     # PD disaggregation: can be "null" (not disaggregated), "prefill" (prefill-only), or "decode" (decode-only)
     disaggregation_mode: Literal["null", "prefill", "decode"] = "null"
...@@ -2831,11 +2830,6 @@ class ServerArgs:
             default=ServerArgs.debug_tensor_dump_inject,
             help="Inject the outputs from jax as the input of every layer.",
         )
-        parser.add_argument(
-            "--debug-tensor-dump-prefill-only",
-            action="store_true",
-            help="Only dump the tensors for prefill requests (i.e. batch size > 1).",
-        )
         parser.add_argument(
             "--enable-dynamic-batch-tokenizer",
             action="store_true",
......
...@@ -34,7 +34,7 @@ def get_model_config(tp_size: int):
         "topk": topk,
         "hidden_size": config.hidden_size,
         "shard_intermediate_size": shard_intermediate_size,
-        "dtype": config.torch_dtype,
+        "dtype": config.dtype,
         "block_shape": config.quantization_config["weight_block_size"],
     }
......
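The `config.torch_dtype` to `config.dtype` rename tracks newer transformers releases, where `torch_dtype` is deprecated in favor of `dtype`. If the benchmark must run against both old and new versions, a hedged fallback (assumption: which attribute exists varies by transformers version):

```python
from transformers import PretrainedConfig

def config_dtype(config: PretrainedConfig):
    # Prefer the new attribute name; fall back to the deprecated one.
    return getattr(config, "dtype", None) or getattr(config, "torch_dtype", None)
```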
-import time
 import unittest
-import requests

 from sglang.srt.utils import kill_process_tree
 from sglang.test.test_deterministic import BenchArgs, test_deterministic
 from sglang.test.test_utils import (
...@@ -55,6 +52,7 @@ class TestDeterministicBase(CustomTestCase):
         args.n_start = 10
         args.n_trials = 20
+        args.temperature = 0.5  # test for deterministic sampling
         results = test_deterministic(args)
         for result in results:
             assert result == 1
...@@ -65,6 +63,7 @@ class TestDeterministicBase(CustomTestCase):
         args.test_mode = "mixed"
         args.n_start = 10
         args.n_trials = 20
+        args.temperature = 0.5  # test for deterministic sampling
         results = test_deterministic(args)
         for result in results:
             assert result == 1
...@@ -76,6 +75,7 @@ class TestDeterministicBase(CustomTestCase):
         args.test_mode = "prefix"
         args.n_start = 10
         args.n_trials = 10
+        args.temperature = 0.5  # test for deterministic sampling
         results = test_deterministic(args)
         for result in results:
             assert result == 1
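The added `temperature = 0.5` lines matter because at the default temperature decoding presumably takes the argmax path, where determinism is trivial; only temperature > 0 exercises the seeded multinomial path. A toy illustration of the property under test, using a plain `torch.Generator` rather than this PR's hash-based scheme:

```python
import torch

probs = torch.tensor([[0.1, 0.2, 0.7]])
gen = torch.Generator().manual_seed(42)
first = torch.multinomial(probs, num_samples=1, generator=gen)
gen.manual_seed(42)  # replay with the identical seed
second = torch.multinomial(probs, num_samples=1, generator=gen)
assert torch.equal(first, second)  # same seed, same sample
```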