Commit 14356d5d authored by zhuwenwen
Browse files

update list and type

parent 52675626
...@@ -956,7 +956,7 @@ def triton_int8_gemm_helper(m: int, ...@@ -956,7 +956,7 @@ def triton_int8_gemm_helper(m: int,
per_token_act_quant: bool, per_token_act_quant: bool,
per_out_channel_weight_quant: bool, per_out_channel_weight_quant: bool,
use_bias: bool, use_bias: bool,
out_dtype: Type[torch.dtype] = torch.float16, out_dtype: type[torch.dtype] = torch.float16,
device: str = "cuda", device: str = "cuda",
best_config:Optional[list] = None): best_config:Optional[list] = None):
return quant_tools.triton_int8_gemm_helper(m,n,k,per_token_act_quant,per_out_channel_weight_quant,use_bias,out_dtype,device,best_config) return quant_tools.triton_int8_gemm_helper(m,n,k,per_token_act_quant,per_out_channel_weight_quant,use_bias,out_dtype,device,best_config)
...@@ -1749,8 +1749,8 @@ def free_shared_buffer(ptr: int) -> None: ...@@ -1749,8 +1749,8 @@ def free_shared_buffer(ptr: int) -> None:
def read_cache( def read_cache(
keys: torch.Tensor, keys: torch.Tensor,
values: torch.Tensor, values: torch.Tensor,
key_caches: List[torch.Tensor], key_caches: list[torch.Tensor],
value_caches: List[torch.Tensor], value_caches: list[torch.Tensor],
slot_mapping: torch.Tensor, slot_mapping: torch.Tensor,
kv_cache_dtype: str kv_cache_dtype: str
) -> None: ) -> None:
...@@ -1761,8 +1761,8 @@ def read_cache( ...@@ -1761,8 +1761,8 @@ def read_cache(
def write_cache_multi_layers( def write_cache_multi_layers(
keys: torch.Tensor, keys: torch.Tensor,
values: torch.Tensor, values: torch.Tensor,
key_caches: List[torch.Tensor], key_caches: list[torch.Tensor],
value_caches: List[torch.Tensor], value_caches: list[torch.Tensor],
slot_mapping: torch.Tensor, slot_mapping: torch.Tensor,
kv_cache_dtype: str kv_cache_dtype: str
) -> None: ) -> None:
......
...@@ -188,7 +188,7 @@ class SequenceData(msgspec.Struct, ...@@ -188,7 +188,7 @@ class SequenceData(msgspec.Struct,
@staticmethod @staticmethod
def from_prompt_token_counts( def from_prompt_token_counts(
*token_counts: Tuple[int, int]) -> "SequenceData": *token_counts: tuple[int, int]) -> "SequenceData":
""" """
Construct a :class:`SequenceData` instance by concatenating Construct a :class:`SequenceData` instance by concatenating
prompt token sequences. prompt token sequences.
...@@ -1334,9 +1334,9 @@ class Logits(msgspec.Struct, array_like=True, ...@@ -1334,9 +1334,9 @@ class Logits(msgspec.Struct, array_like=True,
# all tokens, whereas for decode step, it is used for last accepted tokens. # all tokens, whereas for decode step, it is used for last accepted tokens.
logits: torch.Tensor logits: torch.Tensor
# The sequence group metadata list. Only needed for decode step. # The sequence group metadata list. Only needed for decode step.
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None seq_group_metadata_list: Optional[list[SequenceGroupMetadata]] = None
_seq_ids: List[int] = msgspec.field(default_factory=list) _seq_ids: list[int] = msgspec.field(default_factory=list)
def __post_init__(self): def __post_init__(self):
if self.seq_group_metadata_list is not None: if self.seq_group_metadata_list is not None:
...@@ -1344,12 +1344,12 @@ class Logits(msgspec.Struct, array_like=True, ...@@ -1344,12 +1344,12 @@ class Logits(msgspec.Struct, array_like=True,
self._seq_ids = get_all_seq_ids(self.seq_group_metadata_list) self._seq_ids = get_all_seq_ids(self.seq_group_metadata_list)
@property @property
def seq_ids(self) -> List[int]: def seq_ids(self) -> list[int]:
return self._seq_ids return self._seq_ids
def update(self, def update(self,
logits: torch.Tensor, logits: torch.Tensor,
seq_group_metadata_list: List[SequenceGroupMetadata]): seq_group_metadata_list: list[SequenceGroupMetadata]):
"""Update hidden states from target model invocation. Only used for """Update hidden states from target model invocation. Only used for
decode steps""" decode steps"""
assert len(seq_group_metadata_list) == len(logits) assert len(seq_group_metadata_list) == len(logits)
...@@ -1357,7 +1357,7 @@ class Logits(msgspec.Struct, array_like=True, ...@@ -1357,7 +1357,7 @@ class Logits(msgspec.Struct, array_like=True,
self.logits = torch.cat([self.logits, logits]) self.logits = torch.cat([self.logits, logits])
def prune(self, def prune(self,
seq_group_metadata_list: List[SequenceGroupMetadata]) -> None: seq_group_metadata_list: list[SequenceGroupMetadata]) -> None:
"""Prune to provided list of sequence ids. Only used for decode steps. """Prune to provided list of sequence ids. Only used for decode steps.
""" """
# Currently this prunes all seq_ids not present in # Currently this prunes all seq_ids not present in
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment