Unverified Commit 7ae98875 authored by afeldman-nm's avatar afeldman-nm Committed by GitHub
Browse files

[V1] Logits processor docs (#22919)


Signed-off-by: default avatarAndrew Feldman <afeldman@redhat.com>
Signed-off-by: default avatarafeldman-nm <156691304+afeldman-nm@users.noreply.github.com>
Co-authored-by: default avatarJoseph Marinier <Joseph.Marinier@gmail.com>
parent e3db5ebb
This diff is collapsed.
# Custom Arguments
You can use vLLM *custom arguments* to pass in arguments which are not part of the vLLM `SamplingParams` and REST API specifications. Adding or removing a vLLM custom argument does not require recompiling vLLM, since the custom arguments are passed in as a dictionary.
Custom arguments can be useful if, for example, you want to use a [custom logits processor](./custom_logitsprocs.md) without modifying the vLLM source code.
## Offline Custom Arguments
Custom arguments passed to `SamplingParams.extra_args` as a `dict` will be visible to any code which has access to `SamplingParams`:
``` python
SamplingParams(extra_args={"your_custom_arg_name": 67})
```
This allows arguments which are not already part of `SamplingParams` to be passed into `LLM` as part of a request.
## Online Custom Arguments
The vLLM REST API allows custom arguments to be passed to the vLLM server via `vllm_xargs`. The example below integrates custom arguments into a vLLM REST API request:
``` bash
curl http://localhost:8000/v1/completions \
-H "Content-Type: application/json" \
-d '{
"model": "Qwen/Qwen2.5-1.5B-Instruct",
...
"vllm_xargs": {"your_custom_arg": 67}
}'
```
Furthermore, OpenAI SDK users can access `vllm_xargs` via the `extra_body` argument:
``` python
batch = await client.completions.create(
model="Qwen/Qwen2.5-1.5B-Instruct",
...,
extra_body={
"vllm_xargs": {
"your_custom_arg": 67
}
}
)
```
!!! note
`vllm_xargs` is assigned to `SamplingParams.extra_args` under the hood, so code which uses `SamplingParams.extra_args` is compatible with both offline and online scenarios.
This diff is collapsed.
......@@ -56,7 +56,6 @@ class DummyLogitsProcessor(LogitsProcessor):
self.req_info: dict[int, int] = {}
def is_argmax_invariant(self) -> bool:
"""Never impacts greedy sampling"""
return False
def update_state(self, batch_update: Optional[BatchUpdate]):
......@@ -75,13 +74,12 @@ class DummyLogitsProcessor(LogitsProcessor):
return logits
# Save target values before modification
rows_list = list(self.req_info.keys())
cols = torch.tensor(
[self.req_info[i] for i in rows_list],
dtype=torch.long,
device=logits.device,
list(self.req_info.values()), dtype=torch.long, device=logits.device
)
rows = torch.tensor(
list(self.req_info.keys()), dtype=torch.long, device=logits.device
)
rows = torch.tensor(rows_list, dtype=torch.long, device=logits.device)
values_to_keep = logits[rows, cols].clone()
# Mask all but target tokens
......
......@@ -69,11 +69,12 @@ class DummyLogitsProcessor(LogitsProcessor):
return logits
# Save target values before modification
rows_list = list(self.req_info.keys())
cols = torch.tensor([self.req_info[i] for i in rows_list],
cols = torch.tensor(list(self.req_info.values()),
dtype=torch.long,
device=logits.device)
rows = torch.tensor(list(self.req_info.keys()),
dtype=torch.long,
device=logits.device)
rows = torch.tensor(rows_list, dtype=torch.long, device=logits.device)
values_to_keep = logits[rows, cols].clone()
# Mask all but target tokens
......
......@@ -21,6 +21,9 @@ class MoveDirectionality(Enum):
SWAP = auto()
# Batch indices of any removed requests.
RemovedRequest = int
# (index, params, prompt_tok_ids, output_tok_ids) tuples for new
# requests added to the batch.
AddedRequest = tuple[int, SamplingParams, list[int], list[int]]
......@@ -29,9 +32,6 @@ AddedRequest = tuple[int, SamplingParams, list[int], list[int]]
# one-way moves or two-way swaps of requests in batch
MovedRequest = tuple[int, int, MoveDirectionality]
# Batch indices of any removed requests.
RemovedRequest = int
@dataclass(frozen=True)
class BatchUpdate:
......
......@@ -36,18 +36,18 @@ class BatchUpdateBuilder:
_removed: list[RemovedRequest]
_is_removed_sorted: bool
moved: list[MovedRequest]
added: list[AddedRequest]
moved: list[MovedRequest]
def __init__(
self,
removed: Optional[list[RemovedRequest]] = None,
moved: Optional[list[MovedRequest]] = None,
added: Optional[list[AddedRequest]] = None,
moved: Optional[list[MovedRequest]] = None,
) -> None:
self._removed = removed or []
self.moved = moved or []
self.added = added or []
self.moved = moved or []
self._is_removed_sorted = False
# Used to track changes in the pooling case
......@@ -107,8 +107,8 @@ class BatchUpdateBuilder:
"""Returns True if there were any changes to the batch."""
self._is_removed_sorted = False
self._removed.clear()
self.moved.clear()
self.added.clear()
self.moved.clear()
batch_changed = self.batch_changed
self.batch_changed = False
return batch_changed
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment