import torch

from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer


class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
    """
    Min new tokens penalizer penalizes tokens based on the length of the output.
    """

    def _is_required(self) -> bool:
        return any(
            req.sampling_params.min_new_tokens > 0 for req in self.orchestrator.reqs()
        )

    def _prepare(self):
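        # Per-request minimums as a (batch, 1) column vector so it broadcasts
        # against the running output-length counter in _apply.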
        self.min_new_tokens = torch.tensor(
            data=[
                req.sampling_params.min_new_tokens for req in self.orchestrator.reqs()
            ],
            dtype=torch.int32,
            device=self.orchestrator.device,
        ).unsqueeze_(1)

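        # Gather each request's full stop set: user-supplied stop ids, the
        # tokenizer's additional stop ids, and EOS. pad_sequence turns the
        # ragged per-request lists into a rectangular (batch, max_stops)
        # tensor, padding with vocab_size (one past the last valid id).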
        padded_stop_token_ids = torch.nn.utils.rnn.pad_sequence(
            sequences=[
                torch.tensor(
                    data=(
                        list(
                            (req.sampling_params.stop_token_ids or set())
                            | (req.tokenizer.additional_stop_token_ids or set())
                            | {req.tokenizer.eos_token_id}
                        )
                    ),
                    dtype=torch.int64,
                    device=self.orchestrator.device,
                )
                for req in self.orchestrator.reqs()
            ],
            batch_first=True,
            padding_value=self.orchestrator.vocab_size,
        )
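
        # Scatter -inf into a (batch, vocab_size + 1) table at every stop-token
        # id; padded entries land in the throwaway column at index vocab_size,
        # which the trailing slice removes.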
        self.stop_token_penalties = torch.zeros(
            size=(len(self.orchestrator.reqs()), self.orchestrator.vocab_size + 1),
            dtype=torch.float32,
            device=self.orchestrator.device,
        ).scatter_add_(
            dim=1,
            index=padded_stop_token_ids,
            src=torch.full_like(
                input=padded_stop_token_ids,
                dtype=torch.float32,
                fill_value=float("-inf"),
                device=self.orchestrator.device,
            ),
        )[
            :, : self.orchestrator.vocab_size
        ]

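        # Number of tokens generated so far for each request, as a (batch, 1)
        # counter incremented once per decoding step.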
        self.len_output_tokens = torch.zeros(
            size=(len(self.orchestrator.reqs()), 1),
            dtype=torch.int32,
            device=self.orchestrator.device,
        )

    def _cumulate_output_tokens(self, output_ids: torch.Tensor):
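        # Only the count matters for this penalizer; the token ids themselves
        # are unused.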
        self.len_output_tokens += 1

    def _apply(self, logits: torch.Tensor):
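        # For requests still below their minimum, add -inf to every stop-token
        # logit so stop tokens cannot be sampled; other rows are untouched.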
        mask = (self.len_output_tokens < self.min_new_tokens).expand_as(logits)
        logits[mask] += self.stop_token_penalties[mask]

    def _filter(self, keep_indices: torch.Tensor):
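        # Keep only the rows of requests that survived batch filtering.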
        self.min_new_tokens = self.min_new_tokens[keep_indices]
        self.stop_token_penalties = self.stop_token_penalties[keep_indices]
        self.len_output_tokens = self.len_output_tokens[keep_indices]

    def _merge(self, their: "BatchedMinNewTokensPenalizer"):
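        # Concatenate per-request state along the batch dimension when another
        # batch is merged into this one.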
        self.min_new_tokens = torch.cat(
            [self.min_new_tokens, their.min_new_tokens], dim=0
        )
        self.stop_token_penalties = torch.cat(
            [self.stop_token_penalties, their.stop_token_penalties], dim=0
        )
        self.len_output_tokens = torch.cat(
            [self.len_output_tokens, their.len_output_tokens], dim=0
        )

    # Explicit resource cleanup to aid GC and free CUDA memory promptly
    def _teardown(self) -> None:
        for name in ("min_new_tokens", "stop_token_penalties", "len_output_tokens"):
            if hasattr(self, name):
                delattr(self, name)
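

# A minimal standalone sketch of the masking trick above, runnable without the
# sglang orchestrator. All values (vocab size, stop tokens, minimums) are
# hypothetical and chosen only to illustrate the technique.
if __name__ == "__main__":
    vocab_size = 8

    # Two requests: the first must still produce 3 tokens (stop token {7}),
    # the second has no minimum (stop tokens {5, 7}).
    min_new_tokens = torch.tensor([[3], [0]], dtype=torch.int32)
    padded_stop_token_ids = torch.nn.utils.rnn.pad_sequence(
        sequences=[
            torch.tensor([7], dtype=torch.int64),
            torch.tensor([5, 7], dtype=torch.int64),
        ],
        batch_first=True,
        padding_value=vocab_size,
    )

    # -inf at every stop-token id, 0 elsewhere; padding falls into the extra
    # column at index vocab_size and is sliced away.
    stop_token_penalties = torch.zeros(
        size=(2, vocab_size + 1), dtype=torch.float32
    ).scatter_add_(
        dim=1,
        index=padded_stop_token_ids,
        src=torch.full_like(
            padded_stop_token_ids, fill_value=float("-inf"), dtype=torch.float32
        ),
    )[:, :vocab_size]

    logits = torch.zeros(2, vocab_size)
    len_output_tokens = torch.zeros(2, 1, dtype=torch.int32)
    mask = (len_output_tokens < min_new_tokens).expand_as(logits)
    logits[mask] += stop_token_penalties[mask]

    # Row 0 is -inf at token 7 (minimum not yet met); row 1 is all zeros.
    print(logits)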