Unverified commit 4228be79, authored by Jialin Ouyang, committed by GitHub
Browse files

[Perf] Use np.ndarray instead of list[list[int]] to reduce GC overhead (#28245)


Signed-off-by: Jialin Ouyang <Jialin.Ouyang@gmail.com>
parent 76e4dcf2
...@@ -5,6 +5,7 @@ import random ...@@ -5,6 +5,7 @@ import random
from dataclasses import dataclass from dataclasses import dataclass
from typing import TypeAlias from typing import TypeAlias
import numpy as np
import torch import torch
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
...@@ -369,9 +370,9 @@ class MockEngineCore: ...@@ -369,9 +370,9 @@ class MockEngineCore:
self.generated_logprobs_raw[req_idx][token_idx] self.generated_logprobs_raw[req_idx][token_idx]
) )
logprobs = LogprobsLists( logprobs = LogprobsLists(
[logprobs_token_ids_], np.array([logprobs_token_ids_]),
[logprobs_], np.array([logprobs_]),
[sampled_token_ranks_], np.array([sampled_token_ranks_]),
) )
else: else:
logprobs = None logprobs = None
......
...@@ -74,7 +74,12 @@ class LogprobsProcessor: ...@@ -74,7 +74,12 @@ class LogprobsProcessor:
token_ids_lst, logprobs_lst, ranks_lst, _ = logprobs_lists token_ids_lst, logprobs_lst, ranks_lst, _ = logprobs_lists
for rank, logprobs, token_ids in zip(ranks_lst, logprobs_lst, token_ids_lst): for rank_np, logprobs_np, token_ids_np in zip(
ranks_lst, logprobs_lst, token_ids_lst
):
rank = rank_np.tolist()
logprobs = logprobs_np.tolist()
token_ids = token_ids_np.tolist()
# Detokenize (non-incrementally). # Detokenize (non-incrementally).
decoded_tokens = ( decoded_tokens = (
NONES NONES
......
...@@ -5,6 +5,7 @@ from abc import ABC, abstractmethod ...@@ -5,6 +5,7 @@ from abc import ABC, abstractmethod
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import TYPE_CHECKING, NamedTuple from typing import TYPE_CHECKING, NamedTuple
import numpy as np
import torch import torch
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -15,11 +16,11 @@ else: ...@@ -15,11 +16,11 @@ else:
class LogprobsLists(NamedTuple): class LogprobsLists(NamedTuple):
# [num_reqs x num_generated_tokens, max_num_logprobs + 1] # [num_reqs x num_generated_tokens, max_num_logprobs + 1]
logprob_token_ids: list[list[int]] logprob_token_ids: np.ndarray
# [num_reqs x num_generated_tokens, max_num_logprobs + 1] # [num_reqs x num_generated_tokens, max_num_logprobs + 1]
logprobs: list[list[float]] logprobs: np.ndarray
# [num_reqs x num_generated_tokens] # [num_reqs x num_generated_tokens]
sampled_token_ranks: list[int] sampled_token_ranks: np.ndarray
# [num_reqs] # [num_reqs]
# Used for slicing the logprobs in cases like speculative # Used for slicing the logprobs in cases like speculative
# decoding where the number of generated tokens may be # decoding where the number of generated tokens may be
...@@ -60,9 +61,9 @@ class LogprobsTensors(NamedTuple): ...@@ -60,9 +61,9 @@ class LogprobsTensors(NamedTuple):
def tolists(self, cu_num_generated_tokens: list[int] | None = None): def tolists(self, cu_num_generated_tokens: list[int] | None = None):
return LogprobsLists( return LogprobsLists(
self.logprob_token_ids.tolist(), self.logprob_token_ids.cpu().numpy(),
self.logprobs.tolist(), self.logprobs.cpu().numpy(),
self.selected_token_ranks.tolist(), self.selected_token_ranks.cpu().numpy(),
cu_num_generated_tokens, cu_num_generated_tokens,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment