Unverified Commit 3f770f44 authored by Isotr0py's avatar Isotr0py Committed by GitHub
Browse files

[Performance] Cache loaded custom logitsprocs to avoid overheads (#28462)


Signed-off-by: default avatarIsotr0py <mozf@mail2.sysu.edu.cn>
parent 48c87936
...@@ -5,7 +5,7 @@ import inspect ...@@ -5,7 +5,7 @@ import inspect
import itertools import itertools
from abc import abstractmethod from abc import abstractmethod
from collections.abc import Sequence from collections.abc import Sequence
from functools import partial from functools import lru_cache, partial
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import torch import torch
...@@ -216,11 +216,17 @@ def build_logitsprocs( ...@@ -216,11 +216,17 @@ def build_logitsprocs(
) )
cached_load_custom_logitsprocs = lru_cache(_load_custom_logitsprocs)
def validate_logits_processors_parameters( def validate_logits_processors_parameters(
logits_processors: Sequence[str | type[LogitsProcessor]] | None, logits_processors: Sequence[str | type[LogitsProcessor]] | None,
sampling_params: SamplingParams, sampling_params: SamplingParams,
): ):
for logits_procs in _load_custom_logitsprocs(logits_processors): logits_processors = (
tuple(logits_processors) if logits_processors is not None else None
)
for logits_procs in cached_load_custom_logitsprocs(logits_processors):
logits_procs.validate_params(sampling_params) logits_procs.validate_params(sampling_params)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment