Unverified commit 27acf63b, authored by Lianmin Zheng and committed by GitHub
Browse files

Use torch.compile for scaling penalty (#3133)

parent da6f8081
import argparse import argparse
import itertools import itertools
import time
import torch import torch
import triton import triton
......
...@@ -3,11 +3,16 @@ from typing import List ...@@ -3,11 +3,16 @@ from typing import List
import torch import torch
from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
from sglang.srt.utils import is_cuda_available from sglang.srt.utils import get_compiler_backend
is_cuda = is_cuda_available()
if is_cuda: @torch.compile(dynamic=True, backend=get_compiler_backend())
from sgl_kernel import sampling_scaling_penalties def apply_scaling_penalties(logits, scaling_penalties):
logits[:] = torch.where(
logits > 0,
logits / scaling_penalties,
logits * scaling_penalties,
)
class BatchedRepetitionPenalizer(_BatchedPenalizer): class BatchedRepetitionPenalizer(_BatchedPenalizer):
...@@ -61,16 +66,7 @@ class BatchedRepetitionPenalizer(_BatchedPenalizer): ...@@ -61,16 +66,7 @@ class BatchedRepetitionPenalizer(_BatchedPenalizer):
self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask] self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask]
def _apply(self, logits: torch.Tensor) -> torch.Tensor: def _apply(self, logits: torch.Tensor) -> torch.Tensor:
if is_cuda: apply_scaling_penalties(logits, self.cumulated_repetition_penalties)
return sampling_scaling_penalties(
logits, self.cumulated_repetition_penalties
)
else:
return torch.where(
logits > 0,
logits / self.cumulated_repetition_penalties,
logits * self.cumulated_repetition_penalties,
)
def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor): def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
self.repetition_penalties = self.repetition_penalties[indices_tensor_to_keep] self.repetition_penalties = self.repetition_penalties[indices_tensor_to_keep]
......
...@@ -7,14 +7,11 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple ...@@ -7,14 +7,11 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
import torch import torch
from sglang.srt.utils import is_cuda_available
is_cuda = is_cuda_available()
if is_cuda:
from sgl_kernel import sampling_scaling_penalties
import sglang.srt.sampling.penaltylib as penaltylib import sglang.srt.sampling.penaltylib as penaltylib
from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
from sglang.srt.sampling.penaltylib.penalizers.repetition_penalty import (
apply_scaling_penalties,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -386,14 +383,7 @@ class SamplingBatchInfo: ...@@ -386,14 +383,7 @@ class SamplingBatchInfo:
# repetition # repetition
if self.scaling_penalties is not None: if self.scaling_penalties is not None:
if is_cuda: apply_scaling_penalties(logits, self.scaling_penalties)
logits[:] = sampling_scaling_penalties(logits, self.scaling_penalties)
else:
logits[:] = torch.where(
logits > 0,
logits / self.scaling_penalties,
logits * self.scaling_penalties,
)
# Apply regex vocab_mask # Apply regex vocab_mask
if self.vocab_mask is not None: if self.vocab_mask is not None:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment