Unverified Commit 0e9164b4 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[mypy] Enable type checking for test directory (#5017)

parent 1b8a0d71
...@@ -271,7 +271,7 @@ class PrefixCachingBlockAllocator(BlockAllocator): ...@@ -271,7 +271,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
""" """
source_blocks = get_all_blocks_recursively(last_block) source_blocks = get_all_blocks_recursively(last_block)
forked_blocks = [] forked_blocks: List[Block] = []
prev_block = None prev_block = None
for block in source_blocks: for block in source_blocks:
refcount = self._refcounter.incr(block.block_id) refcount = self._refcounter.incr(block.block_id)
......
...@@ -260,7 +260,7 @@ class BlockSpaceManagerV2(BlockSpaceManager): ...@@ -260,7 +260,7 @@ class BlockSpaceManagerV2(BlockSpaceManager):
# at max extend. # at max extend.
if self.enable_caching: if self.enable_caching:
block_table = self.block_tables[seq.seq_id] block_table = self.block_tables[seq.seq_id]
block_ids = [] block_ids: List[Optional[int]] = []
for block_id in block_table.physical_block_ids: for block_id in block_table.physical_block_ids:
block_ids.append(block_id) block_ids.append(block_id)
self.block_allocator.mark_blocks_as_accessed( self.block_allocator.mark_blocks_as_accessed(
......
...@@ -2,7 +2,7 @@ import ctypes ...@@ -2,7 +2,7 @@ import ctypes
import json import json
import os import os
from itertools import product from itertools import product
from typing import Dict, Optional, Sequence from typing import Dict, List, Optional, Sequence
import torch.distributed as dist import torch.distributed as dist
import torch.multiprocessing as mp import torch.multiprocessing as mp
...@@ -88,7 +88,7 @@ def consumer(batch_tgt: Sequence[int], ...@@ -88,7 +88,7 @@ def consumer(batch_tgt: Sequence[int],
def can_actually_p2p( def can_actually_p2p(
batch_src: Sequence[int], batch_src: Sequence[int],
batch_tgt: Sequence[int], batch_tgt: Sequence[int],
): ) -> Sequence[bool]:
""" """
Usually, checking if P2P access is enabled can be done by Usually, checking if P2P access is enabled can be done by
`torch.cuda.can_device_access_peer(src, tgt)`. However, sometimes `torch.cuda.can_device_access_peer(src, tgt)`. However, sometimes
...@@ -138,7 +138,7 @@ def can_actually_p2p( ...@@ -138,7 +138,7 @@ def can_actually_p2p(
p_tgt.start() p_tgt.start()
p_src.join() p_src.join()
p_tgt.join() p_tgt.join()
result = [] result: List[bool] = []
for src, tgt in zip(batch_src, batch_tgt): for src, tgt in zip(batch_src, batch_tgt):
a = result_queue.get() a = result_queue.get()
b = result_queue.get() b = result_queue.get()
...@@ -188,7 +188,7 @@ def gpu_p2p_access_check(src: int, tgt: int) -> bool: ...@@ -188,7 +188,7 @@ def gpu_p2p_access_check(src: int, tgt: int) -> bool:
# only the local master process (with local_rank == 0) can # only the local master process (with local_rank == 0) can
# enter this block to calculate the cache # enter this block to calculate the cache
logger.info("generating GPU P2P access cache in %s", path) logger.info("generating GPU P2P access cache in %s", path)
cache = {} cache: Dict[str, bool] = {}
ids = list(range(num_dev)) ids = list(range(num_dev))
# batch of all pairs of GPUs # batch of all pairs of GPUs
batch_src, batch_tgt = zip(*list(product(ids, ids))) batch_src, batch_tgt = zip(*list(product(ids, ids)))
......
...@@ -205,7 +205,7 @@ class NCCLLibrary: ...@@ -205,7 +205,7 @@ class NCCLLibrary:
raise e raise e
if so_file not in NCCLLibrary.path_to_dict_mapping: if so_file not in NCCLLibrary.path_to_dict_mapping:
_funcs = {} _funcs: Dict[str, Any] = {}
for func in NCCLLibrary.exported_functions: for func in NCCLLibrary.exported_functions:
f = getattr(self.lib, func.name) f = getattr(self.lib, func.name)
f.restype = func.restype f.restype = func.restype
......
...@@ -2,7 +2,7 @@ import time ...@@ -2,7 +2,7 @@ import time
from contextlib import contextmanager from contextlib import contextmanager
from typing import TYPE_CHECKING, ClassVar, Iterable, List, Optional from typing import TYPE_CHECKING, ClassVar, Iterable, List, Optional
from typing import Sequence as GenericSequence from typing import Sequence as GenericSequence
from typing import Type, TypeVar, Union from typing import Set, Type, TypeVar, Union
from transformers import GenerationConfig, PreTrainedTokenizer from transformers import GenerationConfig, PreTrainedTokenizer
...@@ -973,7 +973,7 @@ class LLMEngine: ...@@ -973,7 +973,7 @@ class LLMEngine:
def remove_lora(self, lora_id: int) -> bool: def remove_lora(self, lora_id: int) -> bool:
return self.model_executor.remove_lora(lora_id) return self.model_executor.remove_lora(lora_id)
def list_loras(self) -> List[int]: def list_loras(self) -> Set[int]:
return self.model_executor.list_loras() return self.model_executor.list_loras()
def check_health(self) -> None: def check_health(self) -> None:
......
...@@ -144,7 +144,7 @@ class Metrics: ...@@ -144,7 +144,7 @@ class Metrics:
# end-metrics-definitions # end-metrics-definitions
def build_1_2_5_buckets(max_value: int): def build_1_2_5_buckets(max_value: int) -> List[int]:
""" """
Builds a list of buckets with increasing powers of 10 multiplied by Builds a list of buckets with increasing powers of 10 multiplied by
mantissa values (1, 2, 5) until the value exceeds the specified maximum. mantissa values (1, 2, 5) until the value exceeds the specified maximum.
...@@ -155,7 +155,7 @@ def build_1_2_5_buckets(max_value: int): ...@@ -155,7 +155,7 @@ def build_1_2_5_buckets(max_value: int):
""" """
mantissa_lst = [1, 2, 5] mantissa_lst = [1, 2, 5]
exponent = 0 exponent = 0
buckets = [] buckets: List[int] = []
while True: while True:
for m in mantissa_lst: for m in mantissa_lst:
value = m * 10**exponent value = m * 10**exponent
......
from typing import Dict, List, Tuple, Union from typing import Dict, List, Optional, Tuple, Union
from vllm.config import SchedulerConfig from vllm.config import SchedulerConfig
from vllm.core.scheduler import Scheduler from vllm.core.scheduler import Scheduler
...@@ -146,8 +146,8 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -146,8 +146,8 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
# Beam search case # Beam search case
# Select the child sequences to keep in the sequence group. # Select the child sequences to keep in the sequence group.
selected_child_seqs = [] selected_child_seqs: List[Tuple[Sequence, Optional[Sequence]]] = []
unselected_child_seqs = [] unselected_child_seqs: List[Tuple[Sequence, Optional[Sequence]]] = []
beam_width = seq_group.sampling_params.best_of beam_width = seq_group.sampling_params.best_of
length_penalty = seq_group.sampling_params.length_penalty length_penalty = seq_group.sampling_params.length_penalty
......
...@@ -2,6 +2,7 @@ import argparse ...@@ -2,6 +2,7 @@ import argparse
import asyncio import asyncio
import sys import sys
from io import StringIO from io import StringIO
from typing import Awaitable, List
import aiohttp import aiohttp
...@@ -114,7 +115,7 @@ async def main(args): ...@@ -114,7 +115,7 @@ async def main(args):
) )
# Submit all requests in the file to the engine "concurrently". # Submit all requests in the file to the engine "concurrently".
response_futures = [] response_futures: List[Awaitable[BatchRequestOutput]] = []
for request_json in (await read_file(args.input_file)).strip().split("\n"): for request_json in (await read_file(args.input_file)).strip().split("\n"):
request = BatchRequestInput.model_validate_json(request_json) request = BatchRequestInput.model_validate_json(request_json)
response_futures.append(run_request(openai_serving_chat, request)) response_futures.append(run_request(openai_serving_chat, request))
......
...@@ -487,7 +487,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -487,7 +487,7 @@ class OpenAIServingChat(OpenAIServing):
final_res = res final_res = res
assert final_res is not None assert final_res is not None
choices = [] choices: List[ChatCompletionResponseChoice] = []
role = self.get_chat_request_role(request) role = self.get_chat_request_role(request)
for output in final_res.outputs: for output in final_res.outputs:
......
...@@ -25,7 +25,7 @@ def request_output_to_embedding_response( ...@@ -25,7 +25,7 @@ def request_output_to_embedding_response(
created_time: int, created_time: int,
model_name: str, model_name: str,
) -> EmbeddingResponse: ) -> EmbeddingResponse:
data = [] data: List[EmbeddingResponseData] = []
num_prompt_tokens = 0 num_prompt_tokens = 0
for idx, final_res in enumerate(final_res_batch): for idx, final_res in enumerate(final_res_batch):
assert final_res is not None assert final_res is not None
......
from typing import List, Optional from typing import List, Optional
from typing import Sequence as GenericSequence
import torch import torch
...@@ -120,7 +121,7 @@ class PackedLoRALayerWeights(LoRALayerWeights): ...@@ -120,7 +121,7 @@ class PackedLoRALayerWeights(LoRALayerWeights):
@classmethod @classmethod
def pack( def pack(
cls, loras: List[Optional["LoRALayerWeights"]] cls, loras: GenericSequence[Optional["LoRALayerWeights"]]
) -> "PackedLoRALayerWeights": ) -> "PackedLoRALayerWeights":
"""Pack a list of LoRAs into a single LoRA. """Pack a list of LoRAs into a single LoRA.
......
...@@ -165,7 +165,7 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager): ...@@ -165,7 +165,7 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
model = self._lora_manager.model model = self._lora_manager.model
supported_lora_modules = model.supported_lora_modules supported_lora_modules = model.supported_lora_modules
packed_modules_mapping = model.packed_modules_mapping packed_modules_mapping = model.packed_modules_mapping
expected_lora_modules = [] expected_lora_modules: List[str] = []
for module in supported_lora_modules: for module in supported_lora_modules:
if module in packed_modules_mapping: if module in packed_modules_mapping:
expected_lora_modules.extend( expected_lora_modules.extend(
......
...@@ -393,7 +393,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -393,7 +393,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
param_data.copy_(loaded_weight) param_data.copy_(loaded_weight)
return return
current_shard_offset = 0 current_shard_offset = 0
shard_offsets = [] shard_offsets: List[Tuple[int, int, int]] = []
for i, output_size in enumerate(self.output_sizes): for i, output_size in enumerate(self.output_sizes):
shard_offsets.append((i, current_shard_offset, output_size)) shard_offsets.append((i, current_shard_offset, output_size))
current_shard_offset += output_size current_shard_offset += output_size
......
...@@ -25,24 +25,25 @@ GPTQ_MARLIN_SUPPORTED_SYM = [True] ...@@ -25,24 +25,25 @@ GPTQ_MARLIN_SUPPORTED_SYM = [True]
# Permutations for Marlin scale shuffling # Permutations for Marlin scale shuffling
def get_scale_perms(num_bits): def get_scale_perms(num_bits: int):
scale_perm = [] scale_perm: List[int] = []
for i in range(8): for i in range(8):
scale_perm.extend([i + 8 * j for j in range(8)]) scale_perm.extend([i + 8 * j for j in range(8)])
scale_perm_single = [] scale_perm_single: List[int] = []
for i in range(4): for i in range(4):
scale_perm_single.extend( scale_perm_single.extend(
[2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
return scale_perm, scale_perm_single return scale_perm, scale_perm_single
def get_pack_factor(num_bits): def get_pack_factor(num_bits: int):
assert (num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS assert (num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS
), f"Unsupported num_bits = {num_bits}" ), f"Unsupported num_bits = {num_bits}"
return 32 // num_bits return 32 // num_bits
def marlin_permute_scales(s, size_k, size_n, group_size, num_bits): def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int,
group_size: int, num_bits: int):
scale_perm, scale_perm_single = get_scale_perms(num_bits) scale_perm, scale_perm_single = get_scale_perms(num_bits)
if group_size < size_k and group_size != -1: if group_size < size_k and group_size != -1:
s = s.reshape((-1, len(scale_perm)))[:, scale_perm] s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
......
"""This file is used for /tests and /benchmarks""" """This file is used for /tests and /benchmarks"""
from typing import Dict, List
import numpy import numpy
import torch import torch
...@@ -11,10 +13,10 @@ import torch ...@@ -11,10 +13,10 @@ import torch
# #
# As a result of this reordering, the vector loads inside the kernel will get the data as it is needed for tensor-core # noqa: E501 # As a result of this reordering, the vector loads inside the kernel will get the data as it is needed for tensor-core # noqa: E501
# (without the need to use ldmatrix instructions) # noqa: E501 # (without the need to use ldmatrix instructions) # noqa: E501
def get_perms_24(num_bits): def get_perms_24(num_bits: int):
perm_list = [] perm_list: List[int] = []
for i in range(32): for i in range(32):
perm1 = [] perm1: List[int] = []
col = i // 4 col = i // 4
col_o = col // 2 col_o = col // 2
for block in [0, 1]: for block in [0, 1]:
...@@ -39,18 +41,18 @@ def get_perms_24(num_bits): ...@@ -39,18 +41,18 @@ def get_perms_24(num_bits):
perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
perm = torch.from_numpy(perm) perm = torch.from_numpy(perm)
scale_perm = [] scale_perm: List[int] = []
for i in range(8): for i in range(8):
scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]]) scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]])
scale_perm_single = [] scale_perm_single: List[int] = []
for i in range(8): for i in range(8):
scale_perm_single.extend([8 * i + j for j in [0, 1, 2, 3, 4, 5, 6, 7]]) scale_perm_single.extend([8 * i + j for j in [0, 1, 2, 3, 4, 5, 6, 7]])
return perm, scale_perm, scale_perm_single return perm, scale_perm, scale_perm_single
marlin_24_perm = {} marlin_24_perm: Dict[int, torch.Tensor] = {}
marlin_24_scale_perm = {} marlin_24_scale_perm: Dict[int, List[int]] = {}
marlin_24_scale_perm_single = {} marlin_24_scale_perm_single: Dict[int, List[int]] = {}
for num_bits in [4, 8]: for num_bits in [4, 8]:
perm_24, scale_perm_24, scale_perm_single_24 = get_perms_24(num_bits) perm_24, scale_perm_24, scale_perm_single_24 = get_perms_24(num_bits)
marlin_24_perm[num_bits] = perm_24 marlin_24_perm[num_bits] = perm_24
......
"""This file is used for /tests and /benchmarks""" """This file is used for /tests and /benchmarks"""
from typing import Dict, List
import numpy import numpy
import torch import torch
...@@ -11,10 +13,10 @@ import torch ...@@ -11,10 +13,10 @@ import torch
# #
# As a result of this reordering, the vector loads inside the kernel will get the data as it is needed for tensor-core # noqa: E501 # As a result of this reordering, the vector loads inside the kernel will get the data as it is needed for tensor-core # noqa: E501
# (without the need to use ldmatrix instructions) # noqa: E501 # (without the need to use ldmatrix instructions) # noqa: E501
def get_perms(num_bits): def get_perms(num_bits: int):
perm_list = [] perm_list: List[int] = []
for i in range(32): for i in range(32):
perm1 = [] perm1: List[int] = []
col = i // 4 col = i // 4
for block in [0, 1]: for block in [0, 1]:
for row in [ for row in [
...@@ -38,19 +40,19 @@ def get_perms(num_bits): ...@@ -38,19 +40,19 @@ def get_perms(num_bits):
perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
perm = torch.from_numpy(perm) perm = torch.from_numpy(perm)
scale_perm = [] scale_perm: List[int] = []
for i in range(8): for i in range(8):
scale_perm.extend([i + 8 * j for j in range(8)]) scale_perm.extend([i + 8 * j for j in range(8)])
scale_perm_single = [] scale_perm_single: List[int] = []
for i in range(4): for i in range(4):
scale_perm_single.extend( scale_perm_single.extend(
[2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
return perm, scale_perm, scale_perm_single return perm, scale_perm, scale_perm_single
marlin_perm = {} marlin_perm: Dict[int, torch.Tensor] = {}
marlin_scale_perm = {} marlin_scale_perm: Dict[int, List[int]] = {}
marlin_scale_perm_single = {} marlin_scale_perm_single: Dict[int, List[int]] = {}
for num_bits in [4, 8]: for num_bits in [4, 8]:
perm, scale_perm, scale_perm_single = get_perms(num_bits) perm, scale_perm, scale_perm_single = get_perms(num_bits)
marlin_perm[num_bits] = perm marlin_perm[num_bits] = perm
......
...@@ -174,7 +174,7 @@ def _apply_min_tokens_penalty( ...@@ -174,7 +174,7 @@ def _apply_min_tokens_penalty(
min_tokens = sampling_params.min_tokens min_tokens = sampling_params.min_tokens
token_ids_to_penalize = sampling_params.all_stop_token_ids token_ids_to_penalize = sampling_params.all_stop_token_ids
if min_tokens > 0 and token_ids_to_penalize: if min_tokens > 0 and token_ids_to_penalize:
seqs_to_penalize = [] seqs_to_penalize: List[int] = []
for j, seq_id in enumerate(seq_ids): for j, seq_id in enumerate(seq_ids):
seq_data = seq_group.seq_data[seq_id] seq_data = seq_group.seq_data[seq_id]
if len(seq_data.output_token_ids) < min_tokens: if len(seq_data.output_token_ids) < min_tokens:
...@@ -285,7 +285,7 @@ def _greedy_sample( ...@@ -285,7 +285,7 @@ def _greedy_sample(
same as the length of selected_seq_groups. If the corresponding same as the length of selected_seq_groups. If the corresponding
seq_group has do_sample=False, tuple contains ([], []) seq_group has do_sample=False, tuple contains ([], [])
""" """
samples = samples.tolist() samples_lst = samples.tolist()
sample_idx = 0 sample_idx = 0
results: SampleResultType = [] results: SampleResultType = []
for seq_group in selected_seq_groups: for seq_group in selected_seq_groups:
...@@ -298,7 +298,7 @@ def _greedy_sample( ...@@ -298,7 +298,7 @@ def _greedy_sample(
assert num_parent_seqs == 1, ( assert num_parent_seqs == 1, (
"Greedy sampling should have only one seq.") "Greedy sampling should have only one seq.")
parent_ids = list(range(num_parent_seqs)) parent_ids = list(range(num_parent_seqs))
next_token_ids = [samples[sample_idx]] next_token_ids = [samples_lst[sample_idx]]
results.append((next_token_ids, parent_ids)) results.append((next_token_ids, parent_ids))
sample_idx += num_parent_seqs sample_idx += num_parent_seqs
return results return results
...@@ -394,7 +394,7 @@ def _beam_search_sample( ...@@ -394,7 +394,7 @@ def _beam_search_sample(
next_token_ids = next_token_ids.tolist() next_token_ids = next_token_ids.tolist()
else: else:
# Generation phase. # Generation phase.
cumulative_logprobs: List[int] = [ cumulative_logprobs: List[float] = [
seq_group.seq_data[seq_id].cumulative_logprob seq_group.seq_data[seq_id].cumulative_logprob
for seq_id in seq_ids for seq_id in seq_ids
] ]
...@@ -466,8 +466,9 @@ def _sample_with_torch( ...@@ -466,8 +466,9 @@ def _sample_with_torch(
categorized_seq_group_ids[sampling_type].append(i) categorized_seq_group_ids[sampling_type].append(i)
sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {} sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
sample_metadata = {} sample_metadata: Dict[SamplingType,
multinomial_samples = {} Tuple[List[int], List[SequenceGroupToSample]]] = {}
multinomial_samples: Dict[SamplingType, torch.Tensor] = {}
# Create output tensor for sampled token ids. # Create output tensor for sampled token ids.
if include_gpu_probs_tensor: if include_gpu_probs_tensor:
...@@ -494,7 +495,7 @@ def _sample_with_torch( ...@@ -494,7 +495,7 @@ def _sample_with_torch(
greedy_samples = torch.argmax(logprobs[long_sample_indices], greedy_samples = torch.argmax(logprobs[long_sample_indices],
dim=-1) dim=-1)
if include_gpu_probs_tensor: if sampled_token_ids_tensor is not None:
# Store sampled tokens in output tensor. # Store sampled tokens in output tensor.
sampled_token_ids_tensor[ sampled_token_ids_tensor[
long_sample_indices] = greedy_samples.unsqueeze(-1) long_sample_indices] = greedy_samples.unsqueeze(-1)
...@@ -522,7 +523,7 @@ def _sample_with_torch( ...@@ -522,7 +523,7 @@ def _sample_with_torch(
probs[long_sample_indices], max_best_of_in_batch, probs[long_sample_indices], max_best_of_in_batch,
**seeded_args) **seeded_args)
if include_gpu_probs_tensor: if sampled_token_ids_tensor is not None:
# Store sampled tokens in output tensor. # Store sampled tokens in output tensor.
sampled_token_ids_tensor[ sampled_token_ids_tensor[
long_sample_indices] = multinomial_samples[sampling_type] long_sample_indices] = multinomial_samples[sampling_type]
...@@ -571,7 +572,9 @@ def _sample_with_triton_kernel( ...@@ -571,7 +572,9 @@ def _sample_with_triton_kernel(
categorized_seq_group_ids[sampling_type].append(i) categorized_seq_group_ids[sampling_type].append(i)
sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {} sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
sample_metadata = {} sample_metadata: Dict[SamplingType,
Tuple[List[int], List[SequenceGroupToSample],
torch.Tensor, torch.Tensor]] = {}
max_best_of_in_batch = 1 max_best_of_in_batch = 1
# Counterintuitively, having two loops here is actually faster. # Counterintuitively, having two loops here is actually faster.
...@@ -1008,14 +1011,14 @@ def _build_sampler_output( ...@@ -1008,14 +1011,14 @@ def _build_sampler_output(
speculative decoding rejection sampling. speculative decoding rejection sampling.
""" """
sampler_output = [] sampler_output: List[CompletionSequenceGroupOutput] = []
for (seq_group, sample_result, group_prompt_logprobs, for (seq_group, sample_result, group_prompt_logprobs,
group_sample_logprobs) in zip(sampling_metadata.seq_groups, group_sample_logprobs) in zip(sampling_metadata.seq_groups,
sample_results, prompt_logprobs, sample_results, prompt_logprobs,
sample_logprobs): sample_logprobs):
seq_ids = seq_group.seq_ids seq_ids = seq_group.seq_ids
next_token_ids, parent_ids = sample_result next_token_ids, parent_ids = sample_result
seq_outputs = [] seq_outputs: List[SequenceOutput] = []
for parent_id, next_token_id, logprobs in zip(parent_ids, for parent_id, next_token_id, logprobs in zip(parent_ids,
next_token_ids, next_token_ids,
group_sample_logprobs): group_sample_logprobs):
......
...@@ -68,7 +68,7 @@ def _get_model_initialization_kwargs( ...@@ -68,7 +68,7 @@ def _get_model_initialization_kwargs(
vision_language_config: Optional[VisionLanguageConfig] vision_language_config: Optional[VisionLanguageConfig]
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Get extra kwargs for model initialization.""" """Get extra kwargs for model initialization."""
extra_kwargs = {} extra_kwargs: Dict[str, Any] = {}
if hasattr(model_class, "supported_lora_modules"): if hasattr(model_class, "supported_lora_modules"):
extra_kwargs["lora_config"] = lora_config extra_kwargs["lora_config"] = lora_config
elif lora_config: elif lora_config:
...@@ -446,7 +446,8 @@ class ShardedStateLoader(BaseModelLoader): ...@@ -446,7 +446,8 @@ class ShardedStateLoader(BaseModelLoader):
Filter out all tensors that share the same memory or a subset of the Filter out all tensors that share the same memory or a subset of the
memory of another tensor. memory of another tensor.
""" """
same_storage_groups = collections.defaultdict(list) same_storage_groups: Dict[Any, List[Tuple[
str, torch.Tensor]]] = collections.defaultdict(list)
for key, tensor in tensors.items(): for key, tensor in tensors.items():
if tensor.numel(): if tensor.numel():
ptr = tensor.untyped_storage().data_ptr() ptr = tensor.untyped_storage().data_ptr()
...@@ -455,7 +456,7 @@ class ShardedStateLoader(BaseModelLoader): ...@@ -455,7 +456,7 @@ class ShardedStateLoader(BaseModelLoader):
def get_end_ptr(tensor: torch.Tensor) -> int: def get_end_ptr(tensor: torch.Tensor) -> int:
return tensor.view(-1)[-1].data_ptr() + tensor.element_size() return tensor.view(-1)[-1].data_ptr() + tensor.element_size()
result = {} result: Dict[str, torch.Tensor] = {}
for group in same_storage_groups.values(): for group in same_storage_groups.values():
for k, t in group: for k, t in group:
a, b = t.data_ptr(), get_end_ptr(t) a, b = t.data_ptr(), get_end_ptr(t)
......
...@@ -329,7 +329,7 @@ def np_cache_weights_iterator( ...@@ -329,7 +329,7 @@ def np_cache_weights_iterator(
# dumping the same model weights to numpy at the same time. # dumping the same model weights to numpy at the same time.
with get_lock(model_name_or_path, cache_dir): with get_lock(model_name_or_path, cache_dir):
if not os.path.exists(weight_names_file): if not os.path.exists(weight_names_file):
weight_names = [] weight_names: List[str] = []
for bin_file in hf_weights_files: for bin_file in hf_weights_files:
state = torch.load(bin_file, map_location="cpu") state = torch.load(bin_file, map_location="cpu")
for name, param in state.items(): for name, param in state.items():
......
...@@ -72,11 +72,11 @@ _MODELS = {**_GENERATION_MODELS, **_EMBEDDING_MODELS} ...@@ -72,11 +72,11 @@ _MODELS = {**_GENERATION_MODELS, **_EMBEDDING_MODELS}
_OOT_MODELS: Dict[str, Type[nn.Module]] = {} _OOT_MODELS: Dict[str, Type[nn.Module]] = {}
# Models not supported by ROCm. # Models not supported by ROCm.
_ROCM_UNSUPPORTED_MODELS = [] _ROCM_UNSUPPORTED_MODELS: List[str] = []
# Models partially supported by ROCm. # Models partially supported by ROCm.
# Architecture -> Reason. # Architecture -> Reason.
_ROCM_PARTIALLY_SUPPORTED_MODELS = { _ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
"Qwen2ForCausalLM": "Qwen2ForCausalLM":
"Sliding window attention is not yet supported in ROCm's flash attention", "Sliding window attention is not yet supported in ROCm's flash attention",
"MistralForCausalLM": "MistralForCausalLM":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment