Unverified Commit 2ac46e94 authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Sync changes on io_struct.py and deterministic ops (#11498)

parent 0aa65f94
......@@ -321,7 +321,6 @@ Please consult the documentation below and [server_args.py](https://github.com/s
| `--debug-tensor-dump-output-folder` | The output folder for debug tensor dumps. | None |
| `--debug-tensor-dump-input-file` | The input file for debug tensor dumps. | None |
| `--debug-tensor-dump-inject` | Enable injection of debug tensor dumps. | False |
| `--debug-tensor-dump-prefill-only` | Enable prefill-only mode for debug tensor dumps. | False |
## PD disaggregation
......
......@@ -240,6 +240,7 @@ class GroupCoordinator:
use_message_queue_broadcaster: bool = False,
group_name: Optional[str] = None,
torch_compile: Optional[bool] = None,
gloo_timeout: timedelta = timedelta(seconds=120 * 60),
):
# Set group info
group_name = group_name or "anonymous"
......@@ -259,7 +260,9 @@ class GroupCoordinator:
)
# a group with `gloo` backend, to allow direct coordination between
# processes through the CPU.
cpu_group = torch.distributed.new_group(ranks, backend="gloo")
cpu_group = torch.distributed.new_group(
ranks, backend="gloo", timeout=gloo_timeout
)
if self.rank in ranks:
self.ranks = ranks
self.world_size = len(ranks)
......
......@@ -91,7 +91,6 @@ class Sampler(nn.Module):
batch_next_token_ids = torch.argmax(logits, -1)
if return_logprob:
logprobs = torch.nn.functional.log_softmax(logits, dim=-1)
else:
# If requested, cache probabilities from original logits before temperature scaling.
if return_logprob and RETURN_ORIGINAL_LOGPROB:
......@@ -288,21 +287,29 @@ def multinomial_with_seed(
"""
n, m = inputs.shape
col_indices = torch.arange(m, device=inputs.device).unsqueeze(0)
step_seed = seed * 19349663 ^ positions * 73856093
step_seed = (seed * 19349663) ^ (positions * 73856093)
seed_expanded = step_seed.unsqueeze(-1)
hashed = seed_expanded * 8589934591 ^ col_indices * 479001599
hashed = (seed_expanded * 8589934591) ^ (col_indices * 479001599)
uniform_samples = (hashed % (2**24)).float() / (2**24)
epsilon = 1e-9
gumbel_noise = -torch.log(-torch.log(uniform_samples + epsilon) + epsilon)
epsilon = 1e-10
uniform_samples = uniform_samples.clamp(epsilon, 1.0 - epsilon)
gumbel_noise = -torch.log(-torch.log(uniform_samples))
log_probs = torch.log(inputs + epsilon)
perturbed_log_probs = log_probs + gumbel_noise
return torch.argmax(perturbed_log_probs, dim=1, keepdim=True)
def sampling_from_probs_torch(probs: torch.Tensor):
def sampling_from_probs_torch(
probs: torch.Tensor,
sampling_seed: Optional[torch.Tensor] = None,
positions: Optional[torch.Tensor] = None,
):
"""A sampling implementation with native pytorch operations, without
top-k, top-p, or min-p filtering."""
sampled_index = torch.multinomial(probs, num_samples=1)
if sampling_seed is not None:
sampled_index = multinomial_with_seed(probs, sampling_seed, positions)
else:
sampled_index = torch.multinomial(probs, num_samples=1)
batch_next_token_ids = sampled_index.view(-1).to(torch.int32)
return batch_next_token_ids
......
......@@ -245,9 +245,11 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
input_token_ids_logprobs_idx=recv_obj.input_token_ids_logprobs_idx,
output_token_ids_logprobs_val=recv_obj.output_token_ids_logprobs_val,
output_token_ids_logprobs_idx=recv_obj.output_token_ids_logprobs_idx,
output_token_entropy_val=recv_obj.output_token_entropy_val,
output_hidden_states=recv_obj.output_hidden_states,
placeholder_tokens_idx=None,
placeholder_tokens_val=None,
token_steps=recv_obj.token_steps,
)
def handle_multimodal_decode_req(self, recv_obj: BatchMultimodalDecodeReq):
......
......@@ -170,6 +170,9 @@ class GenerateReqInput(BaseReq):
# (Internal) Whether to return bytes for image generation
return_bytes: bool = False
# Whether to return entropy
return_entropy: bool = False
def contains_mm_input(self) -> bool:
return (
has_valid_data(self.image_data)
......@@ -568,6 +571,7 @@ class GenerateReqInput(BaseReq):
no_logs=self.no_logs,
custom_labels=self.custom_labels,
return_bytes=self.return_bytes,
return_entropy=self.return_entropy,
)
......@@ -633,6 +637,9 @@ class TokenizedGenerateReqInput(BaseReq):
# (Internal) Whether to return bytes for image generation
return_bytes: bool = False
# Whether to return entropy
return_entropy: bool = False
@dataclass
class BatchTokenizedGenerateReqInput(BaseBatchReq):
......@@ -830,6 +837,7 @@ class BatchTokenIDOutput(BaseBatchReq):
input_token_ids_logprobs_idx: List[List]
output_token_ids_logprobs_val: List[List]
output_token_ids_logprobs_idx: List[List]
output_token_entropy_val: List[float]
# Hidden states
output_hidden_states: List[List[float]]
......@@ -840,6 +848,9 @@ class BatchTokenIDOutput(BaseBatchReq):
placeholder_tokens_idx: List[Optional[List[int]]]
placeholder_tokens_val: List[Optional[List[int]]]
# The trainer step id. Used to know which step's weights are used for sampling.
token_steps: List[List[int]] = None
@dataclass
class BatchMultimodalDecodeReq(BaseBatchReq):
......@@ -861,11 +872,14 @@ class BatchMultimodalDecodeReq(BaseBatchReq):
completion_tokens: List[int]
cached_tokens: List[int]
# Placeholder token info
# The information of placeholder tokens (e.g., image token)
# idx is the index of the token in the prompt after expansion.
# val is the length of padded tokens after expansion.
placeholder_tokens_idx: List[Optional[List[int]]]
placeholder_tokens_val: List[Optional[List[int]]]
return_bytes: bool = False
# The trainer step id. Used to know which step's weights are used for sampling.
token_steps: List[List[int]] = None
@dataclass
......@@ -896,13 +910,20 @@ class BatchStrOutput(BaseBatchReq):
input_token_ids_logprobs_idx: List[List]
output_token_ids_logprobs_val: List[List]
output_token_ids_logprobs_idx: List[List]
output_token_entropy_val: List[float]
# Hidden states
output_hidden_states: List[List[float]]
# The information of placeholder tokens (e.g., image token)
# idx is the index of the token in the prompt after expansion.
# val is the length of padded tokens after expansion.
placeholder_tokens_idx: List[Optional[List[int]]]
placeholder_tokens_val: List[Optional[List[int]]]
# The trainer step id. Used to know which step's weights are used for sampling.
token_steps: List[List[int]] = None
@dataclass
class BatchMultimodalOutput(BaseBatchReq):
......@@ -979,6 +1000,8 @@ class UpdateWeightFromDiskReqInput(BaseReq):
torch_empty_cache: bool = False
# Whether to keep the scheduler paused after weight update
keep_pause: bool = False
# The trainer step id. Used to know which step's weights are used for sampling.
token_step: int = 0
@dataclass
......@@ -1416,6 +1439,16 @@ class WatchLoadUpdateReq(BaseReq):
loads: List[GetLoadReqOutput]
@dataclass
class LazyDumpTensorsReqInput(BaseReq):
pass
@dataclass
class LazyDumpTensorsReqOutput(BaseReq):
success: bool
def _check_all_req_types():
"""A helper function to check all request types are defined in this file."""
import inspect
......
......@@ -190,6 +190,11 @@ def _handle_output_by_index(output, i):
if output.output_token_ids_logprobs_idx
else None
),
output_token_entropy_val=(
[output.output_token_entropy_val[i]]
if output.output_token_entropy_val
else None
),
output_hidden_states=(
[output.output_hidden_states[i]]
if output.output_hidden_states
......@@ -197,6 +202,7 @@ def _handle_output_by_index(output, i):
),
placeholder_tokens_idx=None,
placeholder_tokens_val=None,
token_steps=([output.token_steps[i]] if output.token_steps else None),
)
elif isinstance(output, BatchEmbeddingOutput):
new_output = BatchEmbeddingOutput(
......@@ -306,6 +312,11 @@ def _handle_output_by_index(output, i):
if output.output_token_ids_logprobs_idx
else None
),
output_token_entropy_val=(
[output.output_token_entropy_val[i]]
if output.output_token_entropy_val
else None
),
output_hidden_states=(
[output.output_hidden_states[i]]
if output.output_hidden_states
......@@ -313,6 +324,7 @@ def _handle_output_by_index(output, i):
),
placeholder_tokens_idx=None,
placeholder_tokens_val=None,
token_steps=([output.token_steps[i]] if output.token_steps else None),
)
elif isinstance(output, BatchMultimodalOutput):
new_output = BatchMultimodalOutput(
......
......@@ -920,7 +920,8 @@ class SchedulerOutputProcessorMixin:
input_token_ids_logprobs_idx,
output_token_ids_logprobs_val,
output_token_ids_logprobs_idx,
output_hidden_states,
output_token_entropy_val=None,
output_hidden_states=output_hidden_states,
rids=rids,
placeholder_tokens_idx=None,
placeholder_tokens_val=None,
......
......@@ -73,9 +73,6 @@ logger = logging.getLogger(__name__)
# Dump tensors for debugging
debug_tensor_dump_output_folder = None
debug_tensor_dump_prefill_only = False
# Skip all the other tensor dumps, only dump the target logits
debug_tensor_dump_only_target_logprobs = False
debug_tensor_dump_inject = False
debug_tensor_dump_layers = None
debug_tensor_dump_test = False
......
......@@ -455,7 +455,6 @@ class ServerArgs:
debug_tensor_dump_output_folder: Optional[str] = None
debug_tensor_dump_input_file: Optional[str] = None
debug_tensor_dump_inject: bool = False
debug_tensor_dump_prefill_only: bool = False
# PD disaggregation: can be "null" (not disaggregated), "prefill" (prefill-only), or "decode" (decode-only)
disaggregation_mode: Literal["null", "prefill", "decode"] = "null"
......@@ -2831,11 +2830,6 @@ class ServerArgs:
default=ServerArgs.debug_tensor_dump_inject,
help="Inject the outputs from jax as the input of every layer.",
)
parser.add_argument(
"--debug-tensor-dump-prefill-only",
action="store_true",
help="Only dump the tensors for prefill requests (i.e. batch size > 1).",
)
parser.add_argument(
"--enable-dynamic-batch-tokenizer",
action="store_true",
......
......@@ -34,7 +34,7 @@ def get_model_config(tp_size: int):
"topk": topk,
"hidden_size": config.hidden_size,
"shard_intermediate_size": shard_intermediate_size,
"dtype": config.torch_dtype,
"dtype": config.dtype,
"block_shape": config.quantization_config["weight_block_size"],
}
......
import time
import unittest
import requests
from sglang.srt.utils import kill_process_tree
from sglang.test.test_deterministic import BenchArgs, test_deterministic
from sglang.test.test_utils import (
......@@ -55,6 +52,7 @@ class TestDeterministicBase(CustomTestCase):
args.n_start = 10
args.n_trials = 20
results = test_deterministic(args)
args.temperature = 0.5 # test for deterministic sampling
for result in results:
assert result == 1
......@@ -65,6 +63,7 @@ class TestDeterministicBase(CustomTestCase):
args.test_mode = "mixed"
args.n_start = 10
args.n_trials = 20
args.temperature = 0.5 # test for deterministic sampling
results = test_deterministic(args)
for result in results:
assert result == 1
......@@ -76,6 +75,7 @@ class TestDeterministicBase(CustomTestCase):
args.test_mode = "prefix"
args.n_start = 10
args.n_trials = 10
args.temperature = 0.5 # test for deterministic sampling
results = test_deterministic(args)
for result in results:
assert result == 1
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment