Unverified Commit 89a84b0b authored by Peng Guanwen's avatar Peng Guanwen Committed by GitHub
Browse files

[Core] Use array to speedup padding (#6779)

parent 084a01fd
...@@ -220,7 +220,7 @@ def _apply_min_tokens_penalty( ...@@ -220,7 +220,7 @@ def _apply_min_tokens_penalty(
seqs_to_penalize: List[int] = [] seqs_to_penalize: List[int] = []
for j, seq_id in enumerate(seq_ids): for j, seq_id in enumerate(seq_ids):
seq_data = seq_group.seq_data[seq_id] seq_data = seq_group.seq_data[seq_id]
if len(seq_data.output_token_ids) < min_tokens: if len(seq_data.output_token_ids_array) < min_tokens:
seqs_to_penalize.append(j) seqs_to_penalize.append(j)
if seqs_to_penalize: if seqs_to_penalize:
......
import random import random
from array import array
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
...@@ -329,8 +330,8 @@ class SamplingTensors: ...@@ -329,8 +330,8 @@ class SamplingTensors:
user-defined seed for each sequence. user-defined seed for each sequence.
extra_entropy: extra entropy to use when generating seeds. extra_entropy: extra entropy to use when generating seeds.
""" """
prompt_tokens: List[List[int]] = [] prompt_tokens: List[array] = []
output_tokens: List[List[int]] = [] output_tokens: List[array] = []
top_ks: List[int] = [] top_ks: List[int] = []
temperatures: List[float] = [] temperatures: List[float] = []
top_ps: List[float] = [] top_ps: List[float] = []
...@@ -432,13 +433,15 @@ class SamplingTensors: ...@@ -432,13 +433,15 @@ class SamplingTensors:
if (seq_group.is_prompt if (seq_group.is_prompt
and sampling_params.prompt_logprobs is not None): and sampling_params.prompt_logprobs is not None):
prefill_len = len(seq_group.prompt_logprob_indices) prefill_len = len(seq_group.prompt_logprob_indices)
prompt_tokens.extend([] for _ in range(prefill_len)) prompt_tokens.extend(
output_tokens.extend([] for _ in range(prefill_len)) array('l') for _ in range(prefill_len))
output_tokens.extend(
array('l') for _ in range(prefill_len))
if seq_group.do_sample: if seq_group.do_sample:
for seq_id in seq_ids: for seq_id in seq_ids:
seq_data = seq_group.seq_data[seq_id] seq_data = seq_group.seq_data[seq_id]
prompt_tokens.append(list(seq_data.prompt_token_ids)) prompt_tokens.append(seq_data.prompt_token_ids_array)
output_tokens.append(list(seq_data.output_token_ids)) output_tokens.append(seq_data.output_token_ids_array)
sampling_tensors = SamplingTensors.from_lists( sampling_tensors = SamplingTensors.from_lists(
temperatures, top_ps, top_ks, min_ps, presence_penalties, temperatures, top_ps, top_ks, min_ps, presence_penalties,
...@@ -454,9 +457,9 @@ class SamplingTensors: ...@@ -454,9 +457,9 @@ class SamplingTensors:
frequency_penalties: List[float], frequency_penalties: List[float],
repetition_penalties: List[float], repetition_penalties: List[float],
sampling_seeds: List[int], sample_indices: List[int], sampling_seeds: List[int], sample_indices: List[int],
prompt_tokens: List[List[int]], prompt_tokens: List[array], output_tokens: List[array],
output_tokens: List[List[int]], vocab_size: int, vocab_size: int, extra_seeds_to_generate: int,
extra_seeds_to_generate: int, device: torch.device, device: torch.device,
dtype: torch.dtype) -> "SamplingTensors": dtype: torch.dtype) -> "SamplingTensors":
# Note that the performance will be very bad without # Note that the performance will be very bad without
# pinned memory. # pinned memory.
......
...@@ -3,6 +3,7 @@ import copy ...@@ -3,6 +3,7 @@ import copy
import enum import enum
import math import math
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from array import array
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Set, Tuple, from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Set, Tuple,
...@@ -119,10 +120,10 @@ class SequenceData: ...@@ -119,10 +120,10 @@ class SequenceData:
prompt_token_ids: List[int], prompt_token_ids: List[int],
output_token_ids: Optional[List[int]] = None, output_token_ids: Optional[List[int]] = None,
) -> None: ) -> None:
self._prompt_token_ids: List[int] = list(prompt_token_ids) self._prompt_token_ids = array('l', prompt_token_ids)
self._prompt_token_ids_tuple: Tuple[int, ...] = tuple(prompt_token_ids) self._prompt_token_ids_tuple: Tuple[int, ...] = tuple(prompt_token_ids)
self._output_token_ids: List[int] = ( self._output_token_ids = array(
list(output_token_ids) if output_token_ids is not None else []) 'l', output_token_ids if output_token_ids is not None else [])
self.cumulative_logprob = 0.0 self.cumulative_logprob = 0.0
# The number of tokens that are computed (that run against the model). # The number of tokens that are computed (that run against the model).
...@@ -132,7 +133,7 @@ class SequenceData: ...@@ -132,7 +133,7 @@ class SequenceData:
self._update_cached_all_tokens() self._update_cached_all_tokens()
def _update_cached_all_tokens(self): def _update_cached_all_tokens(self):
self._cached_all_token_ids: List[int] = (self._prompt_token_ids + self._cached_all_token_ids: List[int] = list(self._prompt_token_ids +
self._output_token_ids) self._output_token_ids)
@property @property
...@@ -141,19 +142,27 @@ class SequenceData: ...@@ -141,19 +142,27 @@ class SequenceData:
@prompt_token_ids.setter @prompt_token_ids.setter
def prompt_token_ids(self, new_prompt_token_ids) -> None: def prompt_token_ids(self, new_prompt_token_ids) -> None:
self._prompt_token_ids = list(new_prompt_token_ids) self._prompt_token_ids = array('l', new_prompt_token_ids)
self._prompt_token_ids_tuple = tuple(new_prompt_token_ids) self._prompt_token_ids_tuple = tuple(new_prompt_token_ids)
self._update_cached_all_tokens() self._update_cached_all_tokens()
@property
def prompt_token_ids_array(self) -> array:
return self._prompt_token_ids
@property @property
def output_token_ids(self) -> Tuple[int, ...]: def output_token_ids(self) -> Tuple[int, ...]:
return tuple(self._output_token_ids) return tuple(self._output_token_ids)
@output_token_ids.setter @output_token_ids.setter
def output_token_ids(self, new_output_token_ids) -> None: def output_token_ids(self, new_output_token_ids) -> None:
self._output_token_ids = list(new_output_token_ids) self._output_token_ids = array('l', new_output_token_ids)
self._update_cached_all_tokens() self._update_cached_all_tokens()
@property
def output_token_ids_array(self) -> array:
return self._output_token_ids
def append_token_id(self, token_id: int, logprob: float) -> None: def append_token_id(self, token_id: int, logprob: float) -> None:
self._output_token_ids.append(token_id) self._output_token_ids.append(token_id)
self._cached_all_token_ids.append(token_id) self._cached_all_token_ids.append(token_id)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment