Commit 14356d5d authored by zhuwenwen's avatar zhuwenwen
Browse files

update list and type

parent 52675626
......@@ -956,7 +956,7 @@ def triton_int8_gemm_helper(m: int,
per_token_act_quant: bool,
per_out_channel_weight_quant: bool,
use_bias: bool,
out_dtype: Type[torch.dtype] = torch.float16,
out_dtype: type[torch.dtype] = torch.float16,
device: str = "cuda",
best_config:Optional[list] = None):
return quant_tools.triton_int8_gemm_helper(m,n,k,per_token_act_quant,per_out_channel_weight_quant,use_bias,out_dtype,device,best_config)
......@@ -1749,8 +1749,8 @@ def free_shared_buffer(ptr: int) -> None:
def read_cache(
keys: torch.Tensor,
values: torch.Tensor,
key_caches: List[torch.Tensor],
value_caches: List[torch.Tensor],
key_caches: list[torch.Tensor],
value_caches: list[torch.Tensor],
slot_mapping: torch.Tensor,
kv_cache_dtype: str
) -> None:
......@@ -1761,8 +1761,8 @@ def read_cache(
def write_cache_multi_layers(
keys: torch.Tensor,
values: torch.Tensor,
key_caches: List[torch.Tensor],
value_caches: List[torch.Tensor],
key_caches: list[torch.Tensor],
value_caches: list[torch.Tensor],
slot_mapping: torch.Tensor,
kv_cache_dtype: str
) -> None:
......
......@@ -188,7 +188,7 @@ class SequenceData(msgspec.Struct,
@staticmethod
def from_prompt_token_counts(
*token_counts: Tuple[int, int]) -> "SequenceData":
*token_counts: tuple[int, int]) -> "SequenceData":
"""
Construct a :class:`SequenceData` instance by concatenating
prompt token sequences.
......@@ -1334,9 +1334,9 @@ class Logits(msgspec.Struct, array_like=True,
# all tokens, whereas for decode step, it use used for last accepted tokens.
logits: torch.Tensor
# The sequence group metadata list. Only needed for decode step.
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
seq_group_metadata_list: Optional[list[SequenceGroupMetadata]] = None
_seq_ids: List[int] = msgspec.field(default_factory=list)
_seq_ids: list[int] = msgspec.field(default_factory=list)
def __post_init__(self):
if self.seq_group_metadata_list is not None:
......@@ -1344,12 +1344,12 @@ class Logits(msgspec.Struct, array_like=True,
self._seq_ids = get_all_seq_ids(self.seq_group_metadata_list)
@property
def seq_ids(self) -> List[int]:
def seq_ids(self) -> list[int]:
return self._seq_ids
def update(self,
logits: torch.Tensor,
seq_group_metadata_list: List[SequenceGroupMetadata]):
seq_group_metadata_list: list[SequenceGroupMetadata]):
"""Update hidden states from target model invocation. Only used for
decode steps"""
assert len(seq_group_metadata_list) == len(logits)
......@@ -1357,7 +1357,7 @@ class Logits(msgspec.Struct, array_like=True,
self.logits = torch.cat([self.logits, logits])
def prune(self,
seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
seq_group_metadata_list: list[SequenceGroupMetadata]) -> None:
"""Prune to provided list of sequence ids. Only used for decode steps.
"""
# Currently this prunes all seq_ids not present in
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment