Unverified commit b1104538, authored by Lianmin Zheng, committed by GitHub

Simplify logits penalizer (#2086)

parent 3b44bbee
@@ -1019,7 +1019,7 @@ class ScheduleBatch:
         extend_prefix_lens = self.prefix_lens
         extend_logprob_start_lens = self.extend_logprob_start_lens

-        if self.sampling_info is not None:
+        if self.sampling_info:
             if self.has_grammar:
                 self.sampling_info.grammars = [req.grammar for req in self.reqs]
             else:
@@ -1063,6 +1063,7 @@ class ScheduleBatch:
             out_cache_loc=self.out_cache_loc,
             return_logprob=self.return_logprob,
             decoding_reqs=self.decoding_reqs,
+            sampling_info=dataclasses.replace(self.sampling_info),
         )

     def __str__(self):
@@ -1122,20 +1123,6 @@ class ModelWorkerBatch:
     # Sampling info
     sampling_info: SamplingBatchInfo

-    def copy(self):
-        return dataclasses.replace(self, sampling_info=self.sampling_info.copy())
-
-    def to(self, device: str):
-        self.input_ids = self.input_ids.to(device, non_blocking=True)
-        self.req_pool_indices = self.req_pool_indices.to(device, non_blocking=True)
-        self.seq_lens = self.seq_lens.to(device, non_blocking=True)
-        self.out_cache_loc = self.out_cache_loc.to(device, non_blocking=True)
-        self.req_to_token_pool_records = [
-            (x, y.to(device, non_blocking=True))
-            for x, y in self.req_to_token_pool_records
-        ]
-        self.sampling_info.to(device)

 @triton.jit
 def write_req_to_token_pool_triton(
......
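For context: `dataclasses.replace(obj)` with no field overrides constructs a new instance whose fields all reference the same objects, i.e. a shallow copy. That is what replaces the hand-written `copy()`/`to()` helpers removed above. A minimal sketch of the behavior this relies on (the `Info` class is illustrative, not the real `SamplingBatchInfo`):

```python
import dataclasses

import torch


@dataclasses.dataclass
class Info:
    # Stand-in for SamplingBatchInfo: tensors are shared, not cloned.
    temperatures: torch.Tensor
    grammars: list = None


info = Info(temperatures=torch.ones(4))
copy = dataclasses.replace(info)  # new object, same field references

assert copy is not info                        # distinct dataclass instances
assert copy.temperatures is info.temperatures  # tensors are shared (shallow)
copy.grammars = ["g"]                          # rebinding a field on the copy...
assert info.grammars is None                   # ...does not touch the original
```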
@@ -931,14 +931,14 @@ class Scheduler:
         # Check finish conditions
         logprob_pt = 0
-        for i, req in enumerate(batch.reqs):
+        for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
             if req.is_retracted:
                 continue

             if req.is_being_chunked <= 0:
                 # Inflight reqs' prefill is not finished
                 req.completion_tokens_wo_jump_forward += 1
-                req.output_ids.append(next_token_ids[i])
+                req.output_ids.append(next_token_id)
                 req.check_finished()

                 if req.finished():
@@ -947,7 +947,7 @@ class Scheduler:
                     self.tree_cache.cache_unfinished_req(req)

                 if req.grammar is not None:
-                    req.grammar.accept_token(next_token_ids[i])
+                    req.grammar.accept_token(next_token_id)

                 if req.return_logprob:
                     logprob_pt += self.add_logprob_return_values(
......
@@ -16,6 +16,7 @@ limitations under the License.
 """A tensor parallel worker."""

 import logging
+import threading
 from typing import Optional

 from sglang.srt.configs.model_config import ModelConfig
@@ -138,9 +139,15 @@ class TpModelWorker:
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
         self.model_runner.forward(forward_batch)

-    def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
+    def forward_batch_generation(
+        self,
+        model_worker_batch: ModelWorkerBatch,
+        launch_event: Optional[threading.Event] = None,
+    ):
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
         logits_output = self.model_runner.forward(forward_batch)
+        if launch_event:
+            launch_event.set()
         next_token_ids = self.model_runner.sample(logits_output, model_worker_batch)
         return logits_output, next_token_ids
......
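The new `launch_event` parameter lets a caller be notified as soon as the forward pass has been issued, before sampling completes, so CPU-side work can overlap with it. A rough sketch of the handshake; `run_forward` and `sample` are stand-ins for the real model runner calls:

```python
import threading
import time


def run_forward(batch):
    # Stand-in for model_runner.forward: pretend GPU work is queued here.
    time.sleep(0.01)
    return [0.1] * len(batch)


def sample(logits, batch):
    # Stand-in for model_runner.sample.
    return list(range(len(batch)))


def forward_batch_generation(batch, launch_event=None):
    logits = run_forward(batch)
    if launch_event:
        launch_event.set()  # tell the caller the forward pass is underway
    return sample(logits, batch)


# Caller side: overlap CPU bookkeeping with sampling instead of waiting
# for the whole generation step to finish.
event = threading.Event()
t = threading.Thread(target=forward_batch_generation, args=([1, 2, 3], event))
t.start()
event.wait()  # unblocks as soon as the forward pass is launched
# ... CPU-side work here runs concurrently with sampling ...
t.join()
```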
@@ -15,6 +15,7 @@ limitations under the License.
 """A tensor parallel worker."""

+import dataclasses
 import logging
 import threading
 import time
@@ -107,7 +108,7 @@ class TpModelWorkerClient:
             # Run forward
             logits_output, next_token_ids = self.worker.forward_batch_generation(
-                model_worker_batch
+                model_worker_batch, self.launch_event
             )

             # Update the future token ids map
@@ -134,7 +135,6 @@ class TpModelWorkerClient:
             next_token_ids = next_token_ids.to("cpu", non_blocking=True)
             copy_event.record()

-            self.launch_event.set()
             self.output_queue.put((copy_event, logits_output, next_token_ids))

     def resolve_batch_result(self, bid: int):
@@ -159,7 +159,10 @@ class TpModelWorkerClient:
     def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
         # Push a new batch to the queue
-        self.input_queue.put((model_worker_batch.copy(), self.future_token_ids_ct))
+        model_worker_batch.sampling_info = dataclasses.replace(
+            model_worker_batch.sampling_info
+        )
+        self.input_queue.put((model_worker_batch, self.future_token_ids_ct))

         # Allocate output future objects
         bs = len(model_worker_batch.seq_lens)
......
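Replacing `model_worker_batch.copy()` with a shallow `dataclasses.replace` of just `sampling_info` keeps the queue hand-off cheap: the worker thread can rebind fields on its copy (e.g. `grammars`) without racing the scheduler's view. A hedged sketch of the pattern; the dataclass and queue names below are illustrative, not the real classes:

```python
import dataclasses
import queue
import threading


@dataclasses.dataclass
class SamplingInfo:
    grammars: list = None


input_queue = queue.Queue()


def producer(sampling_info):
    # Enqueue a shallow copy: the consumer may rebind fields without
    # affecting the producer's view, and vice versa.
    input_queue.put(dataclasses.replace(sampling_info))


def consumer():
    info = input_queue.get()
    info.grammars = ["compiled"]  # safe: only the copy is rebound


original = SamplingInfo()
producer(original)
t = threading.Thread(target=consumer)
t.start()
t.join()
assert original.grammars is None  # producer's view is untouched
```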
 import abc
 import dataclasses
-import typing
+from typing import List, Set, Type, Union

 import torch


 @dataclasses.dataclass
 class _ReqLike:
-    origin_input_ids: typing.Union[torch.Tensor, typing.List[int]]
+    origin_input_ids: List[int]


 @dataclasses.dataclass
 class _BatchLike:
-    reqs: typing.List[_ReqLike]
+    reqs: List[_ReqLike]

     def batch_size(self):
         return len(self.reqs)


 class BatchedPenalizerOrchestrator:
-    batch: _BatchLike
-    device: str
-    vocab_size: int
-    penalizers: typing.Dict[typing.Type["_BatchedPenalizer"], "_BatchedPenalizer"]
-
     def __init__(
         self,
         vocab_size: int,
         batch: _BatchLike,
         device: str,
-        Penalizers: typing.Set[typing.Type["_BatchedPenalizer"]],
+        Penalizers: Set[Type["_BatchedPenalizer"]],
     ):
         self.vocab_size = vocab_size
         self.batch = batch
         self.device = device
         self.penalizers = {Penalizer: Penalizer(self) for Penalizer in Penalizers}

         is_required = False
@@ -43,10 +37,12 @@ class BatchedPenalizerOrchestrator:
             is_required |= pen_is_required
         self.is_required = is_required

+        input_ids = [
+            torch.tensor(req.origin_input_ids, dtype=torch.int64, device=self.device)
+            for req in self.reqs()
+        ]
         if self.is_required:
-            self.cumulate_input_tokens(
-                input_ids=[req.origin_input_ids for req in self.reqs()]
-            )
+            self.cumulate_input_tokens(input_ids=input_ids)

     def reqs(self):
         return self.batch.reqs
@@ -54,34 +50,24 @@ class BatchedPenalizerOrchestrator:
     def batch_size(self):
         return self.batch.batch_size()

-    def cumulate_input_tokens(
-        self,
-        input_ids: typing.Union[
-            typing.List[torch.Tensor], typing.List[typing.List[int]]
-        ],
-    ):
+    def cumulate_input_tokens(self, input_ids: List[torch.Tensor]):
         """
         Feed the input tokens to the penalizers.

         Args:
-            input_ids (typing.Union[typing.List[torch.Tensor], typing.List[typing.List[int]]]): The input tokens.
+            input_ids (List[torch.Tensor]): The input tokens.
         """
         token_ids = _TokenIDs(orchestrator=self, token_ids=input_ids)

         for penalizer in self.penalizers.values():
             penalizer.cumulate_input_tokens(input_ids=token_ids)

-    def cumulate_output_tokens(
-        self,
-        output_ids: typing.Union[
-            typing.List[torch.Tensor], typing.List[typing.List[int]]
-        ],
-    ):
+    def cumulate_output_tokens(self, output_ids: torch.Tensor):
         """
         Feed the output tokens to the penalizers.

         Args:
-            output_ids (typing.Union[typing.List[torch.Tensor], typing.List[typing.List[int]]]): The output tokens.
+            output_ids (torch.Tensor): The output tokens.
         """
         if not self.is_required:
             return
@@ -112,14 +98,14 @@ class BatchedPenalizerOrchestrator:
     def filter(
         self,
-        indices_to_keep: typing.List[int],
+        indices_to_keep: List[int],
         indices_tensor_to_keep: torch.Tensor = None,
     ):
         """
         Filter the penalizers based on the indices to keep in the batch.

         Args:
-            indices_to_keep (typing.List[int]): List of indices to keep in the batch.
+            indices_to_keep (List[int]): List of indices to keep in the batch.
             indices_tensor_to_keep (torch.Tensor = None): Tensor of indices to keep in the batch. If not None, it will be used instead of converting indices_to_keep to a tensor.
         """
         if not self.is_required:
@@ -174,32 +160,18 @@ class _TokenIDs:
     Attributes:
         orchestrator (BatchedPenalizerOrchestrator): The orchestrator that this token IDs belong to.
-        token_ids (typing.Union[torch.Tensor, typing.List[torch.Tensor]]): The token IDs.
+        token_ids (Union[torch.Tensor, List[torch.Tensor]]): The token IDs.
         cached_counts (torch.Tensor): The cached occurrence count tensor.
     """

-    orchestrator: BatchedPenalizerOrchestrator
-    token_ids: typing.Union[torch.Tensor, typing.List[torch.Tensor]]
-    cached_counts: torch.Tensor = None
-
     def __init__(
         self,
         orchestrator: BatchedPenalizerOrchestrator,
-        token_ids: typing.Union[
-            typing.List[torch.Tensor], typing.List[typing.List[int]]
-        ],
+        token_ids: Union[torch.Tensor, List[torch.Tensor]],
     ):
         self.orchestrator = orchestrator
-
-        if not isinstance(token_ids[0], torch.Tensor):
-            token_ids = [
-                torch.tensor(
-                    data=ids, dtype=torch.int64, device=self.orchestrator.device
-                )
-                for ids in token_ids
-            ]
-
         self.token_ids = token_ids
+        self.cached_counts = None

     def occurrence_count(self) -> torch.Tensor:
         """
@@ -213,19 +185,13 @@ class _TokenIDs:
         token_ids = self.token_ids

-        if isinstance(token_ids, torch.Tensor):
-            token_ids = token_ids.unsqueeze(1)
-
-        # needs to be long to be used as index in scatter_add
-        if token_ids.dtype != torch.int64:
-            token_ids = token_ids.to(torch.int64)
+        if isinstance(token_ids, list):
+            # TODO: optimize this part

             padded_token_ids = torch.nn.utils.rnn.pad_sequence(
                 sequences=token_ids,
                 batch_first=True,
                 padding_value=self.orchestrator.vocab_size,
             )

             self.cached_counts = torch.zeros(
                 size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size + 1),
                 dtype=torch.int64,
@@ -237,6 +203,16 @@ class _TokenIDs:
             )[
                 :, : self.orchestrator.vocab_size
             ]
+        else:
+            # TODO: optimize this part. We do not need to create this big tensor every time.
+            # We can directly apply the results on the logits.
+            self.cached_counts = torch.zeros(
+                size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size),
+                device=self.orchestrator.device,
+            )
+            self.cached_counts[
+                torch.arange(len(token_ids), device=self.orchestrator.device), token_ids
+            ] = 1

         return self.cached_counts
@@ -246,11 +222,9 @@ class _BatchedPenalizer(abc.ABC):
     An abstract class for a batched penalizer.
     """

-    orchestrator: BatchedPenalizerOrchestrator
-    _is_prepared: bool = False
-
     def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
         self.orchestrator = orchestrator
+        self._is_prepared = False

     def is_prepared(self) -> bool:
         return self._is_prepared
@@ -293,9 +267,7 @@ class _BatchedPenalizer(abc.ABC):
         return self._apply(logits=logits)

-    def filter(
-        self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
-    ):
+    def filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
         if not self.is_prepared():
             return
@@ -360,9 +332,7 @@ class _BatchedPenalizer(abc.ABC):
         pass

     @abc.abstractmethod
-    def _filter(
-        self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
-    ):
+    def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
         """
         Filter the penalizer (tensors or underlying data) based on the indices to keep in the batch.
         """
......
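The two branches of the new `occurrence_count` handle the two shapes `token_ids` can now take: a list of variable-length 1-D tensors (prompts) is padded to a rectangle and counted with `scatter_add_`, while a single 1-D tensor (one decode step, one token per request) reduces to a one-hot row update. A standalone sketch of both paths on a toy vocabulary:

```python
import torch

vocab_size = 5

# Prefill path: one variable-length prompt per request.
prompts = [torch.tensor([1, 2, 2]), torch.tensor([0])]
padded = torch.nn.utils.rnn.pad_sequence(
    prompts, batch_first=True, padding_value=vocab_size  # pad id = vocab_size
)
counts = torch.zeros((len(prompts), vocab_size + 1), dtype=torch.int64)
counts.scatter_add_(dim=1, index=padded, src=torch.ones_like(padded))
counts = counts[:, :vocab_size]  # drop the padding column
# counts[0] == [0, 1, 2, 0, 0]; counts[1] == [1, 0, 0, 0, 0]

# Decode path: one new token per request in a single tensor.
step_tokens = torch.tensor([3, 0])
onehot = torch.zeros((len(step_tokens), vocab_size))
onehot[torch.arange(len(step_tokens)), step_tokens] = 1
# onehot[0] == [0, 0, 0, 1, 0]; onehot[1] == [1, 0, 0, 0, 0]
```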
-import typing
+from typing import List

 import torch

-from ..orchestrator import _BatchedPenalizer, _TokenIDs
+from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs


 class BatchedFrequencyPenalizer(_BatchedPenalizer):
@@ -44,9 +44,6 @@ class BatchedFrequencyPenalizer(_BatchedPenalizer):
         )

     def _teardown(self):
-        del self.frequency_penalties
-        del self.cumulated_frequency_penalties
         self.frequency_penalties = None
         self.cumulated_frequency_penalties = None
@@ -62,9 +59,7 @@ class BatchedFrequencyPenalizer(_BatchedPenalizer):
         logits -= self.cumulated_frequency_penalties
         return logits

-    def _filter(
-        self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
-    ):
+    def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
         self.frequency_penalties = self.frequency_penalties[indices_tensor_to_keep]
         self.cumulated_frequency_penalties = self.cumulated_frequency_penalties[
             indices_tensor_to_keep
......
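For reference, the frequency penalizer consumes these occurrence counts with a plain subtraction: each logit drops in proportion to how often its token has already been generated, scaled by the per-request penalty. A minimal sketch of the arithmetic (not the class itself):

```python
import torch

vocab_size = 5
frequency_penalty = torch.tensor([[0.5], [0.0]])  # per-request strength

# Occurrence counts accumulated over generated tokens (see occurrence_count).
output_counts = torch.tensor([[0, 1, 2, 0, 0], [1, 0, 0, 0, 0]])

logits = torch.ones((2, vocab_size))
logits = logits - frequency_penalty * output_counts
# Request 0: token 2 was emitted twice, so its logit drops by 2 * 0.5.
# Request 1: penalty is 0.0, so its logits are unchanged.
```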
-import typing
+from typing import List

 import torch

-from ..orchestrator import _BatchedPenalizer, _TokenIDs
+from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs


 class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
@@ -70,10 +70,6 @@ class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
         )

     def _teardown(self):
-        del self.min_new_tokens
-        del self.stop_token_penalties
-        del self.len_output_tokens
         self.min_new_tokens = None
         self.stop_token_penalties = None
         self.len_output_tokens = None
@@ -89,9 +85,7 @@ class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
         logits[mask] += self.stop_token_penalties[mask]
         return logits

-    def _filter(
-        self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
-    ):
+    def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
         self.min_new_tokens = self.min_new_tokens[indices_tensor_to_keep]
         self.stop_token_penalties = self.stop_token_penalties[indices_tensor_to_keep]
         self.len_output_tokens = self.len_output_tokens[indices_tensor_to_keep]
......
-import typing
+from typing import List

 import torch

-from ..orchestrator import _BatchedPenalizer, _TokenIDs
+from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs


 class BatchedPresencePenalizer(_BatchedPenalizer):
@@ -44,9 +44,6 @@ class BatchedPresencePenalizer(_BatchedPenalizer):
         )

     def _teardown(self):
-        del self.presence_penalties
-        del self.cumulated_presence_penalties
         self.presence_penalties = None
         self.cumulated_presence_penalties = None
@@ -61,9 +58,7 @@ class BatchedPresencePenalizer(_BatchedPenalizer):
         logits -= self.cumulated_presence_penalties
         return logits

-    def _filter(
-        self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
-    ):
+    def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
         self.presence_penalties = self.presence_penalties[indices_tensor_to_keep]
         self.cumulated_presence_penalties = self.cumulated_presence_penalties[
             indices_tensor_to_keep
......
-import typing
+from typing import List

 import torch

-from ..orchestrator import _BatchedPenalizer, _TokenIDs
+from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs


 class BatchedRepetitionPenalizer(_BatchedPenalizer):
@@ -44,9 +44,6 @@ class BatchedRepetitionPenalizer(_BatchedPenalizer):
         )

     def _teardown(self):
-        del self.repetition_penalties
-        del self.cumulated_repetition_penalties
         self.repetition_penalties = None
         self.cumulated_repetition_penalties = None
@@ -65,9 +62,7 @@ class BatchedRepetitionPenalizer(_BatchedPenalizer):
             logits * self.cumulated_repetition_penalties,
         )

-    def _filter(
-        self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
-    ):
+    def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
         self.repetition_penalties = self.repetition_penalties[indices_tensor_to_keep]
         self.cumulated_repetition_penalties = self.cumulated_repetition_penalties[
             indices_tensor_to_keep
......
@@ -27,10 +27,10 @@ class SamplingBatchInfo:
     # Bias Tensors
     vocab_size: int
+    grammars: Optional[List] = None
     logit_bias: torch.Tensor = None
     vocab_mask: Optional[torch.Tensor] = None
     apply_mask: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None
-    grammars: Optional[List] = None

     # Penalizer
     penalizer_orchestrator: Optional[penaltylib.BatchedPenalizerOrchestrator] = None
@@ -211,25 +211,3 @@ class SamplingBatchInfo:
         self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
             self.logit_bias, other.logit_bias, len(self), len(other), self.device
         )
-
-    def copy(self):
-        return SamplingBatchInfo(
-            temperatures=self.temperatures,
-            top_ps=self.top_ps,
-            top_ks=self.top_ks,
-            min_ps=self.min_ps,
-            is_all_greedy=self.is_all_greedy,
-            need_min_p_sampling=self.need_min_p_sampling,
-            vocab_size=self.vocab_size,
-            device=self.device,
-        )
-
-    def to(self, device: str):
-        for item in [
-            "temperatures",
-            "top_ps",
-            "top_ks",
-            "min_ps",
-        ]:
-            value = getattr(self, item)
-            setattr(self, item, value.to(device, non_blocking=True))
@@ -24,7 +24,6 @@ class SamplingParams:
     def __init__(
         self,
         max_new_tokens: int = 128,
-        min_new_tokens: int = 0,
        stop: Optional[Union[str, List[str]]] = None,
         stop_token_ids: Optional[List[int]] = None,
         temperature: float = 1.0,
@@ -34,6 +33,7 @@ class SamplingParams:
         frequency_penalty: float = 0.0,
         presence_penalty: float = 0.0,
         repetition_penalty: float = 1.0,
+        min_new_tokens: int = 0,
         spaces_between_special_tokens: bool = True,
         regex: Optional[str] = None,
         n: int = 1,
......
@@ -782,7 +782,7 @@ class PortArgs:
     @staticmethod
     def init_new(server_args) -> "PortArgs":
-        port = server_args.port + 42
+        port = server_args.port + random.randint(100, 1000)
         while True:
             if is_port_available(port):
                 break
......
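Randomizing the starting offset makes it unlikely that two servers launched with the same `--port` probe the same candidate first; the loop still walks upward until a free port is found. A simplified sketch of the idea; the `is_port_available` below is a stand-in for sglang's real helper, and the increment step is an assumption:

```python
import random
import socket


def is_port_available(port: int) -> bool:
    # Simplified check: if we can bind, the port is free.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        try:
            s.bind(("", port))
            return True
        except OSError:
            return False


def pick_port(base_port: int) -> int:
    # Random offset so concurrently started servers rarely collide.
    port = base_port + random.randint(100, 1000)
    while not is_port_available(port):
        port += 1  # assumed probing step
    return port
```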
 import dataclasses
 import enum
-import typing
 import unittest
+from typing import Dict, List, Optional, Set, Tuple, Type

 import torch
@@ -16,7 +16,7 @@ from sglang.srt.sampling.penaltylib.orchestrator import (
 class MockSamplingParams:
     frequency_penalty: float = 0.0
     min_new_tokens: int = 0
-    stop_token_ids: typing.List[int] = None
+    stop_token_ids: List[int] = None
     presence_penalty: float = 0.0
     repetition_penalty: float = 1.0
@@ -24,12 +24,12 @@ class MockSamplingParams:
 @dataclasses.dataclass
 class MockTokenizer:
     eos_token_id: int
-    additional_stop_token_ids: typing.Optional[typing.List[int]] = None
+    additional_stop_token_ids: Optional[List[int]] = None


 @dataclasses.dataclass
 class MockReq:
-    origin_input_ids: typing.List[int]
+    origin_input_ids: List[int]
     sampling_params: MockSamplingParams
     tokenizer: MockTokenizer
@@ -42,8 +42,8 @@ class StepType(enum.Enum):
 @dataclasses.dataclass
 class Step:
     type: StepType
-    token_ids: typing.List[int]
-    expected_tensors: typing.Dict[str, torch.Tensor]
+    token_ids: List[int]
+    expected_tensors: Dict[str, torch.Tensor]
     # assume initial logits are all 1
     expected_logits: torch.Tensor
@@ -52,7 +52,7 @@ class Step:
 class Subject:
     sampling_params: MockSamplingParams
     # first step must be input, which will be converted to Req
-    steps: typing.List[Step]
+    steps: List[Step]
     eos_token_id: int = -1

     def __post_init__(self):
@@ -66,7 +66,7 @@ class Subject:
                 f"Expected tensors keys must be the same for all steps. Got {self.steps[i].expected_tensors.keys()} for key={i} and {self.steps[0].expected_tensors.keys()}"
             )

-    def tensor_keys(self, i: int = 0) -> typing.Set[str]:
+    def tensor_keys(self, i: int = 0) -> Set[str]:
         return set(self.steps[i].expected_tensors.keys())

     def to_req(self) -> MockReq:
@@ -80,7 +80,7 @@ class Subject:
 @dataclasses.dataclass
 class Case:
     enabled: bool
-    test_subjects: typing.List[Subject]
+    test_subjects: List[Subject]

     def __post_init__(self):
         # each test_subjects.steps should have the same expected_tensors.keys()
@@ -90,12 +90,12 @@ class Case:
                 f"Expected tensors keys must be the same for all test_subjects. Got {self.test_subjects[i].tensor_keys()} for key={i} and {self.test_subjects[0].tensor_keys()}"
             )

-    def tensor_keys(self, i: int = 0) -> typing.List[str]:
+    def tensor_keys(self, i: int = 0) -> List[str]:
         return set(self.test_subjects[i].tensor_keys())


 class BaseBatchedPenalizerTest(unittest.TestCase):
-    Penalizer: typing.Type[_BatchedPenalizer]
+    Penalizer: Type[_BatchedPenalizer]
     device = "cuda"
     vocab_size = 5
@@ -115,7 +115,7 @@ class BaseBatchedPenalizerTest(unittest.TestCase):
         """
         return torch.tensor(data, **kwargs, device=self.device)

-    def create_test_subjects(self) -> typing.List[Subject]:
+    def create_test_subjects(self) -> List[Subject]:
         raise NotImplementedError()

     def create_test_cases(self):
@@ -127,7 +127,7 @@ class BaseBatchedPenalizerTest(unittest.TestCase):
     def _create_penalizer(
         self, case: Case
-    ) -> typing.Tuple[BatchedPenalizerOrchestrator, _BatchedPenalizer]:
+    ) -> Tuple[BatchedPenalizerOrchestrator, _BatchedPenalizer]:
         orchestrator = BatchedPenalizerOrchestrator(
             vocab_size=self.vocab_size,
             batch=_BatchLike(reqs=[subject.to_req() for subject in case.test_subjects]),
@@ -287,22 +287,24 @@ class BaseBatchedPenalizerTest(unittest.TestCase):
                 if i < len(subject.steps)
             ]

-            inputs: typing.List[typing.List[int]] = []
-            outputs: typing.List[typing.List[int]] = []
+            inputs: List[List[int]] = []
+            outputs: List[List[int]] = []
             for subject in filtered_subjects:
                 step = subject.steps[i]
                 if step.type == StepType.INPUT:
-                    inputs.append(step.token_ids)
-                    outputs.append([])
+                    raise NotImplementedError()
                 else:
                     inputs.append([])
                     outputs.append(step.token_ids)

-            if any(inputs):
-                orchestrator.cumulate_input_tokens(inputs)
-
             if any(outputs):
-                orchestrator.cumulate_output_tokens(outputs)
+                for j in range(max(len(x) for x in outputs)):
+                    tmp_outputs = torch.tensor(
+                        [x[j] for x in outputs],
+                        dtype=torch.int32,
+                        device=orchestrator.device,
+                    )
+                    orchestrator.cumulate_output_tokens(tmp_outputs)

             if penalizer.is_required():
                 self.assertTrue(penalizer.is_prepared())
......
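Since `cumulate_output_tokens` now takes one tensor per decode step (one token id per request) instead of per-request lists, the test transposes its recorded outputs column by column, as the updated loop above shows. A small standalone illustration of that transposition (assuming every request has the same number of steps, as in the test):

```python
import torch

# Per-request outputs over three decode steps.
outputs = [[1, 2, 2], [4, 4, 0]]

# Feed one column (= one decode step) at a time.
for j in range(max(len(x) for x in outputs)):
    step_ids = torch.tensor([x[j] for x in outputs], dtype=torch.int32)
    # step 0 -> tensor([1, 4]); step 1 -> tensor([2, 4]); step 2 -> tensor([2, 0])
    # orchestrator.cumulate_output_tokens(step_ids)  # as in the test
```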
"""
Usage:
python3 -m unittest test_srt_backend.TestSRTBackend.test_gen_min_new_tokens
"""
import unittest import unittest
import sglang as sgl import sglang as sgl
...@@ -68,7 +73,7 @@ class TestSRTBackend(unittest.TestCase): ...@@ -68,7 +73,7 @@ class TestSRTBackend(unittest.TestCase):
# Run twice to capture more bugs # Run twice to capture more bugs
for _ in range(2): for _ in range(2):
accuracy, latency = test_hellaswag_select() accuracy, latency = test_hellaswag_select()
assert accuracy > 0.71, f"{accuracy=}" self.assertGreater(accuracy, 0.71)
def test_gen_min_new_tokens(self): def test_gen_min_new_tokens(self):
test_gen_min_new_tokens() test_gen_min_new_tokens()
......
-import typing
 import unittest
+from typing import List

 import torch
@@ -48,7 +48,11 @@ class BaseBatchedFrequencyPenalizerTest(BaseBatchedPenalizerTest):
             ),
             Step(
                 type=StepType.OUTPUT,
-                token_ids=[1, 2, 2],
+                token_ids=[
+                    1,
+                    2,
+                    2,
+                ],  # This is the output ids of one request in three steps.
                 expected_tensors={
                     "frequency_penalties": self.tensor(
                         [[frequency_penalty] * self.vocab_size], dtype=torch.float32
@@ -76,7 +80,7 @@ class BaseBatchedFrequencyPenalizerTest(BaseBatchedPenalizerTest):
         ],
     )

-    def create_test_subjects(self) -> typing.List[Subject]:
+    def create_test_subjects(self) -> List[Subject]:
         self.enabled = self._create_subject(frequency_penalty=self.frequency_penalty)
         self.disabled = self._create_subject(frequency_penalty=0.0)
......
-import typing
 import unittest
+from typing import List

 import torch
@@ -143,7 +143,7 @@ class TestBatchedMinNewTokensPenalizer(BaseBatchedPenalizerTest):
         ],
     )

-    def create_test_subjects(self) -> typing.List[Subject]:
+    def create_test_subjects(self) -> List[Subject]:
         self.enabled = self._create_subject(min_new_tokens=MIN_NEW_TOKENS)
         self.disabled = self._create_subject(min_new_tokens=0.0)
......
-import typing
 import unittest
+from typing import List

 import torch
@@ -76,7 +76,7 @@ class BaseBatchedPresencePenalizerTest(BaseBatchedPenalizerTest):
         ],
     )

-    def create_test_subjects(self) -> typing.List[Subject]:
+    def create_test_subjects(self) -> List[Subject]:
         self.enabled = self._create_subject(presence_penalty=self.presence_penalty)
         self.disabled = self._create_subject(presence_penalty=0.0)
......
-import typing
 import unittest
+from typing import List

 import torch
@@ -78,7 +78,7 @@ class TestBatchedRepetitionPenalizer(BaseBatchedPenalizerTest):
         ],
    )

-    def create_test_subjects(self) -> typing.List[Subject]:
+    def create_test_subjects(self) -> List[Subject]:
         self.enabled = self._create_subject(repetition_penalty=REPETITION_PENALTY)
         self.disabled = self._create_subject(repetition_penalty=1.0)
......