Unverified Commit d4c29197 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Include private attributes in API documentation (#18614)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 6220f3c6
...@@ -66,6 +66,7 @@ plugins: ...@@ -66,6 +66,7 @@ plugins:
options: options:
show_symbol_type_heading: true show_symbol_type_heading: true
show_symbol_type_toc: true show_symbol_type_toc: true
filters: []
summary: summary:
modules: true modules: true
show_if_no_docstring: true show_if_no_docstring: true
......
...@@ -262,16 +262,16 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler): ...@@ -262,16 +262,16 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
True, then a token can be accepted, else it should be True, then a token can be accepted, else it should be
rejected. rejected.
Given {math}`q(\hat{x}_{n+1}|x_1, \dots, x_n)`, the probability of Given $q(\hat{x}_{n+1}|x_1, \dots, x_n)$, the probability of
{math}`\hat{x}_{n+1}` given context {math}`x_1, \dots, x_n` according $\hat{x}_{n+1}$ given context $x_1, \dots, x_n$ according
to the target model, and {math}`p(\hat{x}_{n+1}|x_1, \dots, x_n)`, the to the target model, and $p(\hat{x}_{n+1}|x_1, \dots, x_n)$, the
same conditional probability according to the draft model, the token same conditional probability according to the draft model, the token
is accepted with probability: is accepted with probability:
:::{math} $$
\min\left(1, \frac{q(\hat{x}_{n+1}|x_1, \dots, x_n)} \min\left(1, \frac{q(\hat{x}_{n+1}|x_1, \dots, x_n)}
{p(\hat{x}_{n+1}|x_1, \dots, x_n)}\right) {p(\hat{x}_{n+1}|x_1, \dots, x_n)}\right)
::: $$
This implementation does not apply causality. When using the output, This implementation does not apply causality. When using the output,
if a token is rejected, subsequent tokens should not be used. if a token is rejected, subsequent tokens should not be used.
...@@ -314,27 +314,28 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler): ...@@ -314,27 +314,28 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
target model is recovered (within hardware numerics). target model is recovered (within hardware numerics).
The probability distribution used in this rejection case is constructed The probability distribution used in this rejection case is constructed
as follows. Given {math}`q(x|x_1, \dots, x_n)`, the probability of as follows. Given $q(x|x_1, \dots, x_n)$, the probability of
{math}`x` given context {math}`x_1, \dots, x_n` according to the target $x$ given context $x_1, \dots, x_n$ according to the target
model and {math}`p(x|x_1, \dots, x_n)`, the same conditional probability model and $p(x|x_1, \dots, x_n)$, the same conditional probability
according to the draft model: according to the draft model:
:::{math} $$
x_{n+1} \sim (q(x|x_1, \dots, x_n) - p(x|x_1, \dots, x_n))_+ x_{n+1} \sim (q(x|x_1, \dots, x_n) - p(x|x_1, \dots, x_n))_+
::: $$
where {math}`(f(x))_+` is defined as: where $(f(x))_+$ is defined as:
:::{math} $$
(f(x))_+ = \frac{\max(0, f(x))}{\sum_x \max(0, f(x))} (f(x))_+ = \frac{\max(0, f(x))}{\sum_x \max(0, f(x))}
::: $$
See https://github.com/vllm-project/vllm/pull/2336 for a visualization See https://github.com/vllm-project/vllm/pull/2336 for a visualization
of the draft, target, and recovered probability distributions. of the draft, target, and recovered probability distributions.
Returns a tensor of shape [batch_size, k, vocab_size]. Returns a tensor of shape [batch_size, k, vocab_size].
Note: This batches operations on GPU and thus constructs the recovered Note:
This batches operations on GPU and thus constructs the recovered
distribution for all tokens, even if they are accepted. This causes distribution for all tokens, even if they are accepted. This causes
division-by-zero errors, so we use self._smallest_positive_value to division-by-zero errors, so we use self._smallest_positive_value to
avoid that. This introduces some drift to the distribution. avoid that. This introduces some drift to the distribution.
......
...@@ -93,29 +93,27 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler): ...@@ -93,29 +93,27 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
Evaluates and returns a mask of accepted tokens based on the Evaluates and returns a mask of accepted tokens based on the
posterior probabilities. posterior probabilities.
Parameters: Args:
---------- target_probs (torch.Tensor): A tensor of shape
target_probs : torch.Tensor (batch_size, k, vocab_size) representing the probabilities of
A tensor of shape (batch_size, k, vocab_size) representing each token in the vocabulary for each position in the proposed
the probabilities of each token in the vocabulary for each sequence. This is the distribution generated by the target
position in the proposed sequence. This is the distribution model.
generated by the target model. draft_token_ids (torch.Tensor): A tensor of shape (batch_size, k)
draft_token_ids : torch.Tensor representing the proposed token ids.
A tensor of shape (batch_size, k) representing the proposed
token ids.
A draft token_id x_{n+k} is accepted if it satisfies the A draft token_id x_{n+k} is accepted if it satisfies the
following condition following condition
:::{math} $$
p_{\text{original}}(x_{n+k} | x_1, x_2, \dots, x_{n+k-1}) > p_{\text{original}}(x_{n+k} | x_1, x_2, \dots, x_{n+k-1}) >
\min \left( \epsilon, \delta * \exp \left( \min \left( \epsilon, \delta * \exp \left(
-H(p_{\text{original}}( -H(p_{\text{original}}(
\cdot | x_1, x_2, \ldots, x_{n+k-1})) \right) \right) \cdot | x_1, x_2, \ldots, x_{n+k-1})) \right) \right)
::: $$
where {math}`p_{\text{original}}` corresponds to target_probs where $p_{\text{original}}$ corresponds to target_probs
and {math}`\epsilon` and {math}`\delta` correspond to hyperparameters and $\epsilon$ and $\delta$ correspond to hyperparameters
specified using self._posterior_threshold and self._posterior_alpha specified using self._posterior_threshold and self._posterior_alpha
This method computes the posterior probabilities for the given This method computes the posterior probabilities for the given
...@@ -126,13 +124,10 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler): ...@@ -126,13 +124,10 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
returns a boolean mask indicating which tokens can be accepted. returns a boolean mask indicating which tokens can be accepted.
Returns: Returns:
------- torch.Tensor: A boolean tensor of shape (batch_size, k) where each
torch.Tensor element indicates whether the corresponding draft token has
A boolean tensor of shape (batch_size, k) where each element been accepted or rejected. True indicates acceptance and false
indicates whether the corresponding draft token has been accepted indicates rejection.
or rejected. True indicates acceptance and false indicates
rejection.
""" """
device = target_probs.device device = target_probs.device
candidates_prob = torch.gather( candidates_prob = torch.gather(
...@@ -156,17 +151,14 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler): ...@@ -156,17 +151,14 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
The recovered token ids will fill the first unmatched token The recovered token ids will fill the first unmatched token
by the target token. by the target token.
Parameters Args:
---------- target_probs (torch.Tensor): A tensor of shape
target_probs : torch.Tensor (batch_size, k, vocab_size) containing the target probability
A tensor of shape (batch_size, k, vocab_size) containing distribution.
the target probability distribution
Returns:
Returns torch.Tensor: A tensor of shape (batch_size, k) with the recovered
------- token ids which are selected from target probs.
torch.Tensor
A tensor of shape (batch_size, k) with the recovered token
ids which are selected from target probs.
""" """
max_indices = torch.argmax(target_probs, dim=-1) max_indices = torch.argmax(target_probs, dim=-1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment