Unverified Commit 4228be79 authored by Jialin Ouyang's avatar Jialin Ouyang Committed by GitHub
Browse files

[Perf] Use np.ndarray instead of list[list[int]] to reduce GC overhead (#28245)


Signed-off-by: default avatarJialin Ouyang <Jialin.Ouyang@gmail.com>
parent 76e4dcf2
......@@ -5,6 +5,7 @@ import random
from dataclasses import dataclass
from typing import TypeAlias
import numpy as np
import torch
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
......@@ -369,9 +370,9 @@ class MockEngineCore:
self.generated_logprobs_raw[req_idx][token_idx]
)
logprobs = LogprobsLists(
[logprobs_token_ids_],
[logprobs_],
[sampled_token_ranks_],
np.array([logprobs_token_ids_]),
np.array([logprobs_]),
np.array([sampled_token_ranks_]),
)
else:
logprobs = None
......
......@@ -74,7 +74,12 @@ class LogprobsProcessor:
token_ids_lst, logprobs_lst, ranks_lst, _ = logprobs_lists
for rank, logprobs, token_ids in zip(ranks_lst, logprobs_lst, token_ids_lst):
for rank_np, logprobs_np, token_ids_np in zip(
ranks_lst, logprobs_lst, token_ids_lst
):
rank = rank_np.tolist()
logprobs = logprobs_np.tolist()
token_ids = token_ids_np.tolist()
# Detokenize (non-incrementally).
decoded_tokens = (
NONES
......
......@@ -5,6 +5,7 @@ from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, NamedTuple
import numpy as np
import torch
if TYPE_CHECKING:
......@@ -15,11 +16,11 @@ else:
class LogprobsLists(NamedTuple):
# [num_reqs x num_generated_tokens, max_num_logprobs + 1]
logprob_token_ids: list[list[int]]
logprob_token_ids: np.ndarray
# [num_reqs x num_generated_tokens, max_num_logprobs + 1]
logprobs: list[list[float]]
logprobs: np.ndarray
# [num_reqs x num_generated_tokens]
sampled_token_ranks: list[int]
sampled_token_ranks: np.ndarray
# [num_reqs]
# Used for slicing the logprobs in cases like speculative
# decoding where the number of generated tokens may be
......@@ -60,9 +61,9 @@ class LogprobsTensors(NamedTuple):
def tolists(self, cu_num_generated_tokens: list[int] | None = None):
return LogprobsLists(
self.logprob_token_ids.tolist(),
self.logprobs.tolist(),
self.selected_token_ranks.tolist(),
self.logprob_token_ids.cpu().numpy(),
self.logprobs.cpu().numpy(),
self.selected_token_ranks.cpu().numpy(),
cu_num_generated_tokens,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment