Unverified Commit 0e9164b4 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[mypy] Enable type checking for test directory (#5017)

parent 1b8a0d71
...@@ -47,5 +47,5 @@ jobs: ...@@ -47,5 +47,5 @@ jobs:
mypy vllm/model_executor --config-file pyproject.toml mypy vllm/model_executor --config-file pyproject.toml
mypy vllm/lora --config-file pyproject.toml mypy vllm/lora --config-file pyproject.toml
mypy vllm/logging --config-file pyproject.toml mypy vllm/logging --config-file pyproject.toml
mypy vllm/model_executor --config-file pyproject.toml mypy tests --config-file pyproject.toml
...@@ -31,7 +31,7 @@ import time ...@@ -31,7 +31,7 @@ import time
import warnings import warnings
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime from datetime import datetime
from typing import AsyncGenerator, List, Optional, Tuple from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
import numpy as np import numpy as np
from backend_request_func import (ASYNC_REQUEST_FUNCS, RequestFuncInput, from backend_request_func import (ASYNC_REQUEST_FUNCS, RequestFuncInput,
...@@ -200,12 +200,12 @@ def calculate_metrics( ...@@ -200,12 +200,12 @@ def calculate_metrics(
dur_s: float, dur_s: float,
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
) -> Tuple[BenchmarkMetrics, List[int]]: ) -> Tuple[BenchmarkMetrics, List[int]]:
actual_output_lens = [] actual_output_lens: List[int] = []
total_input = 0 total_input = 0
completed = 0 completed = 0
itls = [] itls: List[float] = []
tpots = [] tpots: List[float] = []
ttfts = [] ttfts: List[float] = []
for i in range(len(outputs)): for i in range(len(outputs)):
if outputs[i].success: if outputs[i].success:
# We use the tokenizer to count the number of output tokens for all # We use the tokenizer to count the number of output tokens for all
...@@ -265,7 +265,7 @@ async def benchmark( ...@@ -265,7 +265,7 @@ async def benchmark(
disable_tqdm: bool, disable_tqdm: bool,
): ):
if backend in ASYNC_REQUEST_FUNCS: if backend in ASYNC_REQUEST_FUNCS:
request_func = ASYNC_REQUEST_FUNCS.get(backend) request_func = ASYNC_REQUEST_FUNCS[backend]
else: else:
raise ValueError(f"Unknown backend: {backend}") raise ValueError(f"Unknown backend: {backend}")
...@@ -292,7 +292,7 @@ async def benchmark( ...@@ -292,7 +292,7 @@ async def benchmark(
pbar = None if disable_tqdm else tqdm(total=len(input_requests)) pbar = None if disable_tqdm else tqdm(total=len(input_requests))
benchmark_start_time = time.perf_counter() benchmark_start_time = time.perf_counter()
tasks = [] tasks: List[asyncio.Task] = []
async for request in get_request(input_requests, request_rate): async for request in get_request(input_requests, request_rate):
prompt, prompt_len, output_len = request prompt, prompt_len, output_len = request
request_func_input = RequestFuncInput( request_func_input = RequestFuncInput(
...@@ -310,7 +310,7 @@ async def benchmark( ...@@ -310,7 +310,7 @@ async def benchmark(
pbar=pbar))) pbar=pbar)))
outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks) outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)
if not disable_tqdm: if pbar is not None:
pbar.close() pbar.close()
benchmark_duration = time.perf_counter() - benchmark_start_time benchmark_duration = time.perf_counter() - benchmark_start_time
...@@ -466,7 +466,7 @@ def main(args: argparse.Namespace): ...@@ -466,7 +466,7 @@ def main(args: argparse.Namespace):
# Save config and results to json # Save config and results to json
if args.save_result: if args.save_result:
result_json = {} result_json: Dict[str, Any] = {}
# Setup # Setup
current_dt = datetime.now().strftime("%Y%m%d-%H%M%S") current_dt = datetime.now().strftime("%Y%m%d-%H%M%S")
......
...@@ -108,8 +108,8 @@ def run_vllm( ...@@ -108,8 +108,8 @@ def run_vllm(
) )
# Add the requests to the engine. # Add the requests to the engine.
prompts = [] prompts: List[str] = []
sampling_params = [] sampling_params: List[SamplingParams] = []
for prompt, _, output_len in requests: for prompt, _, output_len in requests:
prompts.append(prompt) prompts.append(prompt)
sampling_params.append( sampling_params.append(
......
...@@ -86,9 +86,9 @@ def dequant_no_scale( ...@@ -86,9 +86,9 @@ def dequant_no_scale(
# Compare the optimized 1x16 and 2x8 cuda decompression/dequant kernels against # Compare the optimized 1x16 and 2x8 cuda decompression/dequant kernels against
# the generic pytorch version. # the generic pytorch version.
# Just visual comparison. # Just visual comparison.
def dequant_test(k: int, parts: torch.tensor, nbooks: int, bits: int) -> None: def dequant_test(k: int, parts: torch.Tensor, nbooks: int, bits: int) -> None:
n = parts.sum().item() n = int(parts.sum().item())
device = torch.device('cuda:0') device = torch.device('cuda:0')
...@@ -204,7 +204,7 @@ def main(): ...@@ -204,7 +204,7 @@ def main():
sys.stdout = sys.__stdout__ sys.stdout = sys.__stdout__
def run_grid(m: int, k: int, parts: torch.tensor, nbooks: int, bits: int, def run_grid(m: int, k: int, parts: torch.Tensor, nbooks: int, bits: int,
methods): methods):
# I didn't see visible improvements from increasing these, but feel free :) # I didn't see visible improvements from increasing these, but feel free :)
...@@ -252,10 +252,10 @@ def run_grid(m: int, k: int, parts: torch.tensor, nbooks: int, bits: int, ...@@ -252,10 +252,10 @@ def run_grid(m: int, k: int, parts: torch.tensor, nbooks: int, bits: int,
print('') print('')
def run_timing(num_calls: int, m: int, k: int, parts: torch.tensor, def run_timing(num_calls: int, m: int, k: int, parts: torch.Tensor,
nbooks: int, bits: int, method) -> float: nbooks: int, bits: int, method) -> float:
n = parts.sum().item() n = int(parts.sum().item())
device = torch.device('cuda:0') device = torch.device('cuda:0')
......
import argparse import argparse
from typing import List
import torch import torch
import torch.utils.benchmark as benchmark import torch.utils.benchmark as benchmark
...@@ -23,8 +24,9 @@ ACT_ORDER_OPTS = [False, True] ...@@ -23,8 +24,9 @@ ACT_ORDER_OPTS = [False, True]
K_FULL_OPTS = [False, True] K_FULL_OPTS = [False, True]
def bench_run(results, model, act_order, is_k_full, num_bits, group_size, def bench_run(results: List[benchmark.Measurement], model: str,
size_m, size_k, size_n): act_order: bool, is_k_full: bool, num_bits: int, group_size: int,
size_m: int, size_k: int, size_n: int):
label = "Quant Matmul" label = "Quant Matmul"
sub_label = ("{}, act={} k_full={}, b={}, g={}, " sub_label = ("{}, act={} k_full={}, b={}, g={}, "
...@@ -156,7 +158,7 @@ def main(args): ...@@ -156,7 +158,7 @@ def main(args):
for i, model in enumerate(args.models): for i, model in enumerate(args.models):
print(f"[{i}] {model}") print(f"[{i}] {model}")
results = [] results: List[benchmark.Measurement] = []
for model in args.models: for model in args.models:
for layer in WEIGHT_SHAPES[model]: for layer in WEIGHT_SHAPES[model]:
......
import argparse import argparse
import time import time
from datetime import datetime from datetime import datetime
from typing import Any, Dict, List, Tuple from typing import Any, Dict, List, Tuple, TypedDict
import ray import ray
import torch import torch
...@@ -12,8 +12,17 @@ from transformers import AutoConfig ...@@ -12,8 +12,17 @@ from transformers import AutoConfig
from vllm.model_executor.layers.fused_moe.fused_moe import * from vllm.model_executor.layers.fused_moe.fused_moe import *
class BenchmarkConfig(TypedDict):
BLOCK_SIZE_M: int
BLOCK_SIZE_N: int
BLOCK_SIZE_K: int
GROUP_SIZE_M: int
num_warps: int
num_stages: int
def benchmark_config( def benchmark_config(
config: Dict[str, int], config: BenchmarkConfig,
num_tokens: int, num_tokens: int,
num_experts: int, num_experts: int,
shard_intermediate_size: int, shard_intermediate_size: int,
...@@ -92,7 +101,7 @@ def benchmark_config( ...@@ -92,7 +101,7 @@ def benchmark_config(
start_event = torch.cuda.Event(enable_timing=True) start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True)
latencies = [] latencies: List[float] = []
for i in range(num_iters): for i in range(num_iters):
prepare(i) prepare(i)
torch.cuda.synchronize() torch.cuda.synchronize()
...@@ -111,7 +120,7 @@ def get_configs_compute_bound() -> List[Dict[str, int]]: ...@@ -111,7 +120,7 @@ def get_configs_compute_bound() -> List[Dict[str, int]]:
# Reduced search space for faster tuning. # Reduced search space for faster tuning.
# TODO(woosuk): Increase the search space and use a performance model to # TODO(woosuk): Increase the search space and use a performance model to
# prune the search space. # prune the search space.
configs = [] configs: List[BenchmarkConfig] = []
for num_stages in [2, 3, 4, 5]: for num_stages in [2, 3, 4, 5]:
for block_m in [16, 32, 64, 128, 256]: for block_m in [16, 32, 64, 128, 256]:
for block_k in [64, 128, 256]: for block_k in [64, 128, 256]:
...@@ -175,8 +184,8 @@ class BenchmarkWorker: ...@@ -175,8 +184,8 @@ class BenchmarkWorker:
topk: int, topk: int,
dtype: torch.dtype, dtype: torch.dtype,
use_fp8: bool, use_fp8: bool,
search_space: List[Dict[str, int]], search_space: List[BenchmarkConfig],
) -> Dict[str, int]: ) -> BenchmarkConfig:
best_config = None best_config = None
best_time = float("inf") best_time = float("inf")
for config in tqdm(search_space): for config in tqdm(search_space):
...@@ -199,10 +208,11 @@ class BenchmarkWorker: ...@@ -199,10 +208,11 @@ class BenchmarkWorker:
best_config = config best_config = config
now = datetime.now() now = datetime.now()
print(f"{now.ctime()}] Completed tuning for batch_size={num_tokens}") print(f"{now.ctime()}] Completed tuning for batch_size={num_tokens}")
assert best_config is not None
return best_config return best_config
def sort_config(config: Dict[str, int]) -> Dict[str, int]: def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
return { return {
"BLOCK_SIZE_M": config["BLOCK_SIZE_M"], "BLOCK_SIZE_M": config["BLOCK_SIZE_M"],
"BLOCK_SIZE_N": config["BLOCK_SIZE_N"], "BLOCK_SIZE_N": config["BLOCK_SIZE_N"],
...@@ -214,7 +224,7 @@ def sort_config(config: Dict[str, int]) -> Dict[str, int]: ...@@ -214,7 +224,7 @@ def sort_config(config: Dict[str, int]) -> Dict[str, int]:
def save_configs( def save_configs(
configs: Dict[int, Dict[str, int]], configs: Dict[int, BenchmarkConfig],
num_experts: int, num_experts: int,
shard_intermediate_size: int, shard_intermediate_size: int,
hidden_size: int, hidden_size: int,
......
import argparse import argparse
import random import random
import time import time
from typing import Optional from typing import List, Optional
import torch import torch
...@@ -54,14 +54,17 @@ def main( ...@@ -54,14 +54,17 @@ def main(
# Create the block tables. # Create the block tables.
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
block_tables = [] block_tables_lst: List[List[int]] = []
for _ in range(num_seqs): for _ in range(num_seqs):
block_table = [ block_table = [
random.randint(0, NUM_BLOCKS - 1) random.randint(0, NUM_BLOCKS - 1)
for _ in range(max_num_blocks_per_seq) for _ in range(max_num_blocks_per_seq)
] ]
block_tables.append(block_table) block_tables_lst.append(block_table)
block_tables = torch.tensor(block_tables, dtype=torch.int, device=device)
block_tables = torch.tensor(block_tables_lst,
dtype=torch.int,
device=device)
# Create the KV cache. # Create the KV cache.
key_caches, value_caches = create_kv_caches_with_random(NUM_BLOCKS, key_caches, value_caches = create_kv_caches_with_random(NUM_BLOCKS,
......
import argparse import argparse
from itertools import accumulate from itertools import accumulate
from typing import Optional from typing import List, Optional
import nvtx import nvtx
import torch import torch
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import (RotaryEmbedding,
get_rope)
def benchmark_rope_kernels_multi_lora( def benchmark_rope_kernels_multi_lora(
...@@ -37,7 +38,7 @@ def benchmark_rope_kernels_multi_lora( ...@@ -37,7 +38,7 @@ def benchmark_rope_kernels_multi_lora(
}) })
# non-batched RoPE takes only one scaling factor, we create multiple # non-batched RoPE takes only one scaling factor, we create multiple
# instances to simulate the same behavior # instances to simulate the same behavior
non_batched_ropes = [] non_batched_ropes: List[RotaryEmbedding] = []
for scaling_factor in scaling_factors: for scaling_factor in scaling_factors:
non_batched_ropes.append( non_batched_ropes.append(
get_rope(head_size, rotary_dim, max_position, base, is_neox_style, get_rope(head_size, rotary_dim, max_position, base, is_neox_style,
......
...@@ -2,7 +2,7 @@ import argparse ...@@ -2,7 +2,7 @@ import argparse
import glob import glob
import json import json
import os import os
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple from typing import Any, Callable, Dict, List, Optional, Tuple
import numpy as np import numpy as np
import torch import torch
...@@ -19,7 +19,7 @@ def _prepare_hf_weights( ...@@ -19,7 +19,7 @@ def _prepare_hf_weights(
quantized_model_dir: str, quantized_model_dir: str,
load_format: str = "auto", load_format: str = "auto",
fall_back_to_pt: bool = True, fall_back_to_pt: bool = True,
) -> Tuple[str, List[str], bool]: ) -> Tuple[List[str], bool]:
if not os.path.isdir(quantized_model_dir): if not os.path.isdir(quantized_model_dir):
raise FileNotFoundError( raise FileNotFoundError(
f"The quantized model directory `{quantized_model_dir}` " f"The quantized model directory `{quantized_model_dir}` "
...@@ -94,7 +94,7 @@ def _hf_tensorfile_iterator(filename: str, load_format: str, ...@@ -94,7 +94,7 @@ def _hf_tensorfile_iterator(filename: str, load_format: str,
def _kv_scales_extractor( def _kv_scales_extractor(
hf_tensor_files: Iterable[str], hf_tensor_files: List[str],
use_safetensors: bool, use_safetensors: bool,
rank_keyword: str = "rank", rank_keyword: str = "rank",
expected_tp_size: Optional[int] = None) -> Dict[int, Dict[int, float]]: expected_tp_size: Optional[int] = None) -> Dict[int, Dict[int, float]]:
...@@ -115,7 +115,7 @@ def _kv_scales_extractor( ...@@ -115,7 +115,7 @@ def _kv_scales_extractor(
for char in rank_keyword: for char in rank_keyword:
assert not char.isdecimal( assert not char.isdecimal(
), f"Rank keyword {rank_keyword} contains a numeric character!" ), f"Rank keyword {rank_keyword} contains a numeric character!"
rank_scales_map = {} rank_scales_map: Dict[int, Dict[int, float]] = {}
for tensor_file in hf_tensor_files: for tensor_file in hf_tensor_files:
try: try:
rank_idx = tensor_file.find(rank_keyword) rank_idx = tensor_file.find(rank_keyword)
...@@ -141,7 +141,7 @@ def _kv_scales_extractor( ...@@ -141,7 +141,7 @@ def _kv_scales_extractor(
raise raise
if rank not in rank_scales_map: if rank not in rank_scales_map:
layer_scales_map = {} layer_scales_map: Dict[int, float] = {}
rank_scales_map[rank] = layer_scales_map rank_scales_map[rank] = layer_scales_map
else: else:
raise RuntimeError( raise RuntimeError(
...@@ -222,7 +222,7 @@ def _metadata_extractor(quantized_model_dir: str, ...@@ -222,7 +222,7 @@ def _metadata_extractor(quantized_model_dir: str,
"does not exist.") "does not exist.")
metadata_files = glob.glob(os.path.join(quantized_model_dir, "*.json")) metadata_files = glob.glob(os.path.join(quantized_model_dir, "*.json"))
result = {} result: Dict[str, Any] = {}
for file in metadata_files: for file in metadata_files:
with open(file) as f: with open(file) as f:
try: try:
......
...@@ -5,7 +5,7 @@ distributively on a multi-nodes cluster. ...@@ -5,7 +5,7 @@ distributively on a multi-nodes cluster.
Learn more about Ray Data in https://docs.ray.io/en/latest/data/data.html Learn more about Ray Data in https://docs.ray.io/en/latest/data/data.html
""" """
from typing import Dict from typing import Any, Dict, List
import numpy as np import numpy as np
import ray import ray
...@@ -40,8 +40,8 @@ class LLMPredictor: ...@@ -40,8 +40,8 @@ class LLMPredictor:
# The output is a list of RequestOutput objects that contain the prompt, # The output is a list of RequestOutput objects that contain the prompt,
# generated text, and other information. # generated text, and other information.
outputs = self.llm.generate(batch["text"], sampling_params) outputs = self.llm.generate(batch["text"], sampling_params)
prompt = [] prompt: List[str] = []
generated_text = [] generated_text: List[str] = []
for output in outputs: for output in outputs:
prompt.append(output.prompt) prompt.append(output.prompt)
generated_text.append(' '.join([o.text for o in output.outputs])) generated_text.append(' '.join([o.text for o in output.outputs]))
...@@ -71,7 +71,7 @@ def scheduling_strategy_fn(): ...@@ -71,7 +71,7 @@ def scheduling_strategy_fn():
pg, placement_group_capture_child_tasks=True)) pg, placement_group_capture_child_tasks=True))
resources_kwarg = {} resources_kwarg: Dict[str, Any] = {}
if tensor_parallel_size == 1: if tensor_parallel_size == 1:
# For tensor_parallel_size == 1, we simply set num_gpus=1. # For tensor_parallel_size == 1, we simply set num_gpus=1.
resources_kwarg["num_gpus"] = 1 resources_kwarg["num_gpus"] = 1
......
...@@ -111,7 +111,7 @@ mypy vllm/spec_decode --config-file pyproject.toml ...@@ -111,7 +111,7 @@ mypy vllm/spec_decode --config-file pyproject.toml
mypy vllm/model_executor --config-file pyproject.toml mypy vllm/model_executor --config-file pyproject.toml
mypy vllm/lora --config-file pyproject.toml mypy vllm/lora --config-file pyproject.toml
mypy vllm/logging --config-file pyproject.toml mypy vllm/logging --config-file pyproject.toml
mypy vllm/model_executor --config-file pyproject.toml mypy tests --config-file pyproject.toml
# If git diff returns a file that is in the skip list, the file may be checked anyway: # If git diff returns a file that is in the skip list, the file may be checked anyway:
......
from typing import List
import pytest import pytest
from vllm.core.block.block_table import BlockTable from vllm.core.block.block_table import BlockTable
...@@ -28,7 +30,7 @@ def test_allocate_naive(block_size: int, sequence_len: int): ...@@ -28,7 +30,7 @@ def test_allocate_naive(block_size: int, sequence_len: int):
token_ids = list(range(sequence_len)) token_ids = list(range(sequence_len))
num_blocks_per_alloc = len(list(chunk_list(token_ids, block_size))) num_blocks_per_alloc = len(list(chunk_list(token_ids, block_size)))
block_tables = [] block_tables: List[BlockTable] = []
for i in range(5): for i in range(5):
assert allocator.get_num_free_blocks( assert allocator.get_num_free_blocks(
device=Device.GPU) == num_gpu_blocks - i * num_blocks_per_alloc device=Device.GPU) == num_gpu_blocks - i * num_blocks_per_alloc
...@@ -73,7 +75,7 @@ def test_allocate_prefix_caching(block_size: int, sequence_len: int): ...@@ -73,7 +75,7 @@ def test_allocate_prefix_caching(block_size: int, sequence_len: int):
num_immutable_blocks_per_alloc = len( num_immutable_blocks_per_alloc = len(
chunked_tokens) - num_mutable_blocks_per_alloc chunked_tokens) - num_mutable_blocks_per_alloc
block_tables = [] block_tables: List[BlockTable] = []
for alloc_i in range(1, 6): for alloc_i in range(1, 6):
block_tables.append( block_tables.append(
...@@ -268,7 +270,7 @@ def test_append_token_ids_correct_content(block_size: int, sequence_len: int, ...@@ -268,7 +270,7 @@ def test_append_token_ids_correct_content(block_size: int, sequence_len: int,
) )
block_table.allocate(token_ids=token_ids, device=Device.GPU) block_table.allocate(token_ids=token_ids, device=Device.GPU)
appended_so_far = [] appended_so_far: List[int] = []
for append in chunk_list(token_ids_to_append, append_size): for append in chunk_list(token_ids_to_append, append_size):
block_table.append_token_ids(append) block_table.append_token_ids(append)
appended_so_far.extend(append) appended_so_far.extend(append)
......
...@@ -123,7 +123,7 @@ class TestPrefixCachingBlock: ...@@ -123,7 +123,7 @@ class TestPrefixCachingBlock:
num_empty_trailing_blocks=0) -> List[PrefixCachingBlock]: num_empty_trailing_blocks=0) -> List[PrefixCachingBlock]:
"""Helper method which creates a chain of blocks. """Helper method which creates a chain of blocks.
""" """
blocks = [] blocks: List[PrefixCachingBlock] = []
num_blocks = math.ceil( num_blocks = math.ceil(
len(token_ids) / block_size) + num_empty_trailing_blocks len(token_ids) / block_size) + num_empty_trailing_blocks
...@@ -608,7 +608,7 @@ class TestPrefixCachingBlockAllocator: ...@@ -608,7 +608,7 @@ class TestPrefixCachingBlockAllocator:
) -> List[PrefixCachingBlock]: ) -> List[PrefixCachingBlock]:
"""Helper method which creates a chain of blocks. """Helper method which creates a chain of blocks.
""" """
blocks = [] blocks: List[Block] = []
num_blocks = math.ceil(len(token_ids) / block_size) num_blocks = math.ceil(len(token_ids) / block_size)
if num_blocks == 0: if num_blocks == 0:
......
...@@ -483,11 +483,11 @@ def test_chunked_prefill_preempt(): ...@@ -483,11 +483,11 @@ def test_chunked_prefill_preempt():
# The request should be preempted. # The request should be preempted.
scheduler.block_manager.can_append_slots = MagicMock() scheduler.block_manager.can_append_slots = MagicMock()
def cannot_append_second_group(seq_group, num_lookahead_slots): def cannot_append_second_group1(seq_group, num_lookahead_slots):
return seq_group.request_id != "1" return seq_group.request_id != "1"
scheduler.block_manager.can_append_slots.side_effect = ( scheduler.block_manager.can_append_slots.side_effect = (
cannot_append_second_group) cannot_append_second_group1)
# The running prefill is now preempted. # The running prefill is now preempted.
_, out = schedule_and_update_computed_tokens(scheduler) _, out = schedule_and_update_computed_tokens(scheduler)
...@@ -505,11 +505,11 @@ def test_chunked_prefill_preempt(): ...@@ -505,11 +505,11 @@ def test_chunked_prefill_preempt():
assert seq_group.get_num_uncomputed_tokens() == 30 assert seq_group.get_num_uncomputed_tokens() == 30
# We should be able to run prefill twice as it is chunked. # We should be able to run prefill twice as it is chunked.
def cannot_append_second_group(seq_group, num_lookahead_slots): def cannot_append_second_group2(seq_group, num_lookahead_slots):
return True return True
scheduler.block_manager.can_append_slots.side_effect = ( scheduler.block_manager.can_append_slots.side_effect = (
cannot_append_second_group) cannot_append_second_group2)
_, out = schedule_and_update_computed_tokens(scheduler) _, out = schedule_and_update_computed_tokens(scheduler)
assert len(out.scheduled_seq_groups) == 1 assert len(out.scheduled_seq_groups) == 1
assert out.num_prefill_groups == 1 assert out.num_prefill_groups == 1
...@@ -530,7 +530,7 @@ def test_chunked_prefill_max_seqs(): ...@@ -530,7 +530,7 @@ def test_chunked_prefill_max_seqs():
cache_config.num_cpu_blocks = 8 cache_config.num_cpu_blocks = 8
cache_config.num_gpu_blocks = 8 cache_config.num_gpu_blocks = 8
scheduler = Scheduler(scheduler_config, cache_config, None) scheduler = Scheduler(scheduler_config, cache_config, None)
running = [] running: List[SequenceGroup] = []
_, seq_group = create_dummy_prompt("1", prompt_length=65) _, seq_group = create_dummy_prompt("1", prompt_length=65)
scheduler.add_seq_group(seq_group) scheduler.add_seq_group(seq_group)
......
import time import time
from collections import deque from collections import deque
from typing import List from typing import Deque, List, Set, Tuple
from unittest.mock import MagicMock from unittest.mock import MagicMock
import pytest # noqa import pytest # noqa
...@@ -65,7 +65,7 @@ def test_scheduler_abort_seq_group(): ...@@ -65,7 +65,7 @@ def test_scheduler_abort_seq_group():
# Add multiple seq groups to scheduler. # Add multiple seq groups to scheduler.
num_seq_group = 4 num_seq_group = 4
request_ids = set() request_ids: Set[str] = set()
for i in range(num_seq_group): for i in range(num_seq_group):
_, seq_group = create_dummy_prompt(str(i), block_size) _, seq_group = create_dummy_prompt(str(i), block_size)
scheduler.add_seq_group(seq_group) scheduler.add_seq_group(seq_group)
...@@ -347,7 +347,7 @@ def test_prefill_schedule_max_prompt_len(): ...@@ -347,7 +347,7 @@ def test_prefill_schedule_max_prompt_len():
Test prompt longer than max_prompt_len is aborted. Test prompt longer than max_prompt_len is aborted.
""" """
scheduler = initialize_scheduler(max_model_len=30) scheduler = initialize_scheduler(max_model_len=30)
_, seq_group = create_dummy_prompt(0, prompt_length=60) _, seq_group = create_dummy_prompt("0", prompt_length=60)
waiting = deque([seq_group]) waiting = deque([seq_group])
budget = create_token_budget() budget = create_token_budget()
remaining_waiting, output = scheduler._schedule_prefills( remaining_waiting, output = scheduler._schedule_prefills(
...@@ -364,7 +364,7 @@ def test_prefill_schedule_token_budget(): ...@@ -364,7 +364,7 @@ def test_prefill_schedule_token_budget():
Test token budget respected. Test token budget respected.
""" """
scheduler = initialize_scheduler() scheduler = initialize_scheduler()
waiting = deque() waiting: Deque[SequenceGroup] = deque()
budget = create_token_budget(token_budget=0) budget = create_token_budget(token_budget=0)
for i in range(2): for i in range(2):
_, seq_group = create_dummy_prompt(str(i), prompt_length=60) _, seq_group = create_dummy_prompt(str(i), prompt_length=60)
...@@ -419,7 +419,7 @@ def test_prefill_schedule_max_seqs(): ...@@ -419,7 +419,7 @@ def test_prefill_schedule_max_seqs():
Test max seq respected. Test max seq respected.
""" """
scheduler = initialize_scheduler() scheduler = initialize_scheduler()
waiting = deque() waiting: Deque[SequenceGroup] = deque()
budget = create_token_budget(max_num_seqs=2) budget = create_token_budget(max_num_seqs=2)
for i in range(3): for i in range(3):
_, seq_group = create_dummy_prompt(str(i), prompt_length=60) _, seq_group = create_dummy_prompt(str(i), prompt_length=60)
...@@ -453,9 +453,9 @@ def test_prefill_schedule_max_lora(): ...@@ -453,9 +453,9 @@ def test_prefill_schedule_max_lora():
""" """
lora_config = LoRAConfig(max_lora_rank=8, max_loras=1) lora_config = LoRAConfig(max_lora_rank=8, max_loras=1)
scheduler = initialize_scheduler(lora_config=lora_config) scheduler = initialize_scheduler(lora_config=lora_config)
waiting = deque() waiting: Deque[SequenceGroup] = deque()
budget = create_token_budget(token_budget=120) budget = create_token_budget(token_budget=120)
curr_loras = set() curr_loras: Set[int] = set()
for i in range(2): for i in range(2):
_, seq_group = create_dummy_prompt(str(i), _, seq_group = create_dummy_prompt(str(i),
prompt_length=60, prompt_length=60,
...@@ -499,7 +499,7 @@ def test_prefill_schedule_no_block_manager_capacity(): ...@@ -499,7 +499,7 @@ def test_prefill_schedule_no_block_manager_capacity():
Test sequence cannot be scheduled due to block manager has no capacity. Test sequence cannot be scheduled due to block manager has no capacity.
""" """
scheduler = initialize_scheduler() scheduler = initialize_scheduler()
waiting = deque() waiting: Deque[SequenceGroup] = deque()
budget = create_token_budget() budget = create_token_budget()
for i in range(3): for i in range(3):
_, seq_group = create_dummy_prompt(str(i), prompt_length=60) _, seq_group = create_dummy_prompt(str(i), prompt_length=60)
...@@ -536,7 +536,7 @@ def test_decode_schedule_preempted(): ...@@ -536,7 +536,7 @@ def test_decode_schedule_preempted():
Test decodes cannot be scheduled and preempted. Test decodes cannot be scheduled and preempted.
""" """
scheduler = initialize_scheduler() scheduler = initialize_scheduler()
running = deque() running: Deque[SequenceGroup] = deque()
policy = PolicyFactory.get_policy(policy_name="fcfs") policy = PolicyFactory.get_policy(policy_name="fcfs")
curr_loras = None curr_loras = None
for i in range(3): for i in range(3):
...@@ -577,7 +577,7 @@ def test_decode_swap_beam_search(): ...@@ -577,7 +577,7 @@ def test_decode_swap_beam_search():
Test best_of > 1 swap out blocks Test best_of > 1 swap out blocks
""" """
scheduler = initialize_scheduler() scheduler = initialize_scheduler()
running = deque() running: Deque[SequenceGroup] = deque()
policy = PolicyFactory.get_policy(policy_name="fcfs") policy = PolicyFactory.get_policy(policy_name="fcfs")
curr_loras = None curr_loras = None
budget = create_token_budget() budget = create_token_budget()
...@@ -628,7 +628,7 @@ def test_schedule_decode_blocks_to_copy_update(): ...@@ -628,7 +628,7 @@ def test_schedule_decode_blocks_to_copy_update():
""" """
scheduler = initialize_scheduler() scheduler = initialize_scheduler()
_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
running = deque() running: Deque[SequenceGroup] = deque()
policy = PolicyFactory.get_policy(policy_name="fcfs") policy = PolicyFactory.get_policy(policy_name="fcfs")
curr_loras = None curr_loras = None
scheduler._allocate_and_set_running(seq_group) scheduler._allocate_and_set_running(seq_group)
...@@ -656,10 +656,10 @@ def test_schedule_decode_blocks_to_copy_update(): ...@@ -656,10 +656,10 @@ def test_schedule_decode_blocks_to_copy_update():
def test_schedule_swapped_simple(): def test_schedule_swapped_simple():
scheduler = initialize_scheduler() scheduler = initialize_scheduler()
swapped = deque() swapped: Deque[SequenceGroup] = deque()
policy = PolicyFactory.get_policy(policy_name="fcfs") policy = PolicyFactory.get_policy(policy_name="fcfs")
curr_loras = None curr_loras = None
blocks_to_swap_out = [] blocks_to_swap_out: List[Tuple[int, int]] = []
_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
scheduler._allocate_and_set_running(seq_group) scheduler._allocate_and_set_running(seq_group)
append_new_token_seq_group(60, seq_group, 1) append_new_token_seq_group(60, seq_group, 1)
...@@ -683,10 +683,10 @@ def test_schedule_swapped_simple(): ...@@ -683,10 +683,10 @@ def test_schedule_swapped_simple():
def test_schedule_swapped_max_token_budget(): def test_schedule_swapped_max_token_budget():
scheduler = initialize_scheduler() scheduler = initialize_scheduler()
swapped = deque() swapped: Deque[SequenceGroup] = deque()
policy = PolicyFactory.get_policy(policy_name="fcfs") policy = PolicyFactory.get_policy(policy_name="fcfs")
curr_loras = None curr_loras = None
blocks_to_swap_out = [] blocks_to_swap_out: List[Tuple[int, int]] = []
for _ in range(2): for _ in range(2):
_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
scheduler._allocate_and_set_running(seq_group) scheduler._allocate_and_set_running(seq_group)
...@@ -717,10 +717,10 @@ def test_schedule_swapped_max_token_budget(): ...@@ -717,10 +717,10 @@ def test_schedule_swapped_max_token_budget():
def test_schedule_swapped_max_seqs(): def test_schedule_swapped_max_seqs():
scheduler = initialize_scheduler() scheduler = initialize_scheduler()
swapped = deque() swapped: Deque[SequenceGroup] = deque()
policy = PolicyFactory.get_policy(policy_name="fcfs") policy = PolicyFactory.get_policy(policy_name="fcfs")
curr_loras = None curr_loras = None
blocks_to_swap_out = [] blocks_to_swap_out: List[Tuple[int, int]] = []
for i in range(4): for i in range(4):
_, seq_group = create_dummy_prompt(str(i), prompt_length=60) _, seq_group = create_dummy_prompt(str(i), prompt_length=60)
scheduler._allocate_and_set_running(seq_group) scheduler._allocate_and_set_running(seq_group)
...@@ -750,10 +750,10 @@ def test_schedule_swapped_max_seqs(): ...@@ -750,10 +750,10 @@ def test_schedule_swapped_max_seqs():
def test_schedule_swapped_max_loras(): def test_schedule_swapped_max_loras():
lora_config = LoRAConfig(max_lora_rank=8, max_loras=1) lora_config = LoRAConfig(max_lora_rank=8, max_loras=1)
scheduler = initialize_scheduler(lora_config=lora_config) scheduler = initialize_scheduler(lora_config=lora_config)
swapped = deque() swapped: Deque[SequenceGroup] = deque()
policy = PolicyFactory.get_policy(policy_name="fcfs") policy = PolicyFactory.get_policy(policy_name="fcfs")
curr_loras = set() curr_loras: Set[int] = set()
blocks_to_swap_out = [] blocks_to_swap_out: List[Tuple[int, int]] = []
for i in range(2): for i in range(2):
_, seq_group = create_dummy_prompt(str(i), _, seq_group = create_dummy_prompt(str(i),
prompt_length=60, prompt_length=60,
...@@ -779,10 +779,10 @@ def test_schedule_swapped_max_loras(): ...@@ -779,10 +779,10 @@ def test_schedule_swapped_max_loras():
def test_schedule_swapped_cannot_swap_in(): def test_schedule_swapped_cannot_swap_in():
scheduler = initialize_scheduler() scheduler = initialize_scheduler()
swapped = deque() swapped: Deque[SequenceGroup] = deque()
policy = PolicyFactory.get_policy(policy_name="fcfs") policy = PolicyFactory.get_policy(policy_name="fcfs")
curr_loras = None curr_loras = None
blocks_to_swap_out = [] blocks_to_swap_out: List[Tuple[int, int]] = []
for _ in range(2): for _ in range(2):
_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
scheduler._allocate_and_set_running(seq_group) scheduler._allocate_and_set_running(seq_group)
...@@ -806,10 +806,10 @@ def test_schedule_swapped_cannot_swap_in(): ...@@ -806,10 +806,10 @@ def test_schedule_swapped_cannot_swap_in():
def test_infeasible_swap(): def test_infeasible_swap():
scheduler = initialize_scheduler() scheduler = initialize_scheduler()
swapped = deque() swapped: Deque[SequenceGroup] = deque()
policy = PolicyFactory.get_policy(policy_name="fcfs") policy = PolicyFactory.get_policy(policy_name="fcfs")
curr_loras = None curr_loras = None
blocks_to_swap_out = [] blocks_to_swap_out: List[Tuple[int, int]] = []
for _ in range(2): for _ in range(2):
_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
scheduler._allocate_and_set_running(seq_group) scheduler._allocate_and_set_running(seq_group)
...@@ -834,13 +834,13 @@ def test_infeasible_swap(): ...@@ -834,13 +834,13 @@ def test_infeasible_swap():
def test_schedule_swapped_blocks_to_copy(): def test_schedule_swapped_blocks_to_copy():
scheduler = initialize_scheduler() scheduler = initialize_scheduler()
swapped = deque() swapped: Deque[SequenceGroup] = deque()
policy = PolicyFactory.get_policy(policy_name="fcfs") policy = PolicyFactory.get_policy(policy_name="fcfs")
curr_loras = None curr_loras = None
_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
scheduler._allocate_and_set_running(seq_group) scheduler._allocate_and_set_running(seq_group)
append_new_token_seq_group(60, seq_group, 1) append_new_token_seq_group(60, seq_group, 1)
blocks_to_swap_out = [] blocks_to_swap_out: List[Tuple[int, int]] = []
scheduler._swap_out(seq_group, blocks_to_swap_out) scheduler._swap_out(seq_group, blocks_to_swap_out)
swapped.append(seq_group) swapped.append(seq_group)
......
import time import time
from typing import Iterable, Optional, Tuple from typing import List, Optional
from typing import Sequence as GenericSequence
from typing import Tuple
from vllm import SamplingParams from vllm import SamplingParams
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
...@@ -46,7 +48,7 @@ def create_dummy_prompt_encoder_decoder( ...@@ -46,7 +48,7 @@ def create_dummy_prompt_encoder_decoder(
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
use_beam_search: bool = False, use_beam_search: bool = False,
best_of: int = 1, best_of: int = 1,
) -> Tuple[Sequence, SequenceGroup]: ) -> Tuple[Sequence, Sequence, SequenceGroup]:
if not block_size: if not block_size:
block_size = decoder_prompt_length block_size = decoder_prompt_length
...@@ -86,7 +88,7 @@ def create_dummy_prompt_encoder_decoder( ...@@ -86,7 +88,7 @@ def create_dummy_prompt_encoder_decoder(
def create_seq_group( def create_seq_group(
seq_prompt_len: int = 1024, seq_prompt_len: int = 1024,
seq_output_lens: Iterable[int] = (128, ), seq_output_lens: GenericSequence[int] = (128, ),
request_id: str = '0', request_id: str = '0',
seq_id_start: int = 0, seq_id_start: int = 0,
sampling_params: Optional[SamplingParams] = None) -> SequenceGroup: sampling_params: Optional[SamplingParams] = None) -> SequenceGroup:
...@@ -98,7 +100,7 @@ def create_seq_group( ...@@ -98,7 +100,7 @@ def create_seq_group(
prompt_token_ids = [0] * seq_prompt_len prompt_token_ids = [0] * seq_prompt_len
seqs = [] seqs: List[Sequence] = []
for seq_id_offset, output_len in enumerate(seq_output_lens): for seq_id_offset, output_len in enumerate(seq_output_lens):
seq = Sequence( seq = Sequence(
seq_id=seq_id_start + seq_id_offset, seq_id=seq_id_start + seq_id_offset,
...@@ -125,7 +127,7 @@ def create_seq_group( ...@@ -125,7 +127,7 @@ def create_seq_group(
def create_seq_group_encoder_decoder( def create_seq_group_encoder_decoder(
seq_prompt_len: int = 1024, seq_prompt_len: int = 1024,
seq_output_lens: Iterable[int] = (128, ), seq_output_lens: GenericSequence[int] = (128, ),
request_id: str = '0', request_id: str = '0',
seq_id_start: int = 0, seq_id_start: int = 0,
sampling_params: Optional[SamplingParams] = None) -> SequenceGroup: sampling_params: Optional[SamplingParams] = None) -> SequenceGroup:
......
import multiprocessing import multiprocessing
import os import os
from typing import Dict, List
import pytest import pytest
import torch import torch
...@@ -17,9 +18,9 @@ from vllm.utils import update_environment_variables ...@@ -17,9 +18,9 @@ from vllm.utils import update_environment_variables
def distributed_run(fn, world_size): def distributed_run(fn, world_size):
number_of_processes = world_size number_of_processes = world_size
processes = [] processes: List[multiprocessing.Process] = []
for i in range(number_of_processes): for i in range(number_of_processes):
env = {} env: Dict[str, str] = {}
env['RANK'] = str(i) env['RANK'] = str(i)
env['LOCAL_RANK'] = str(i) env['LOCAL_RANK'] = str(i)
env['WORLD_SIZE'] = str(number_of_processes) env['WORLD_SIZE'] = str(number_of_processes)
......
...@@ -6,7 +6,7 @@ from vllm.utils import cuda_device_count_stateless ...@@ -6,7 +6,7 @@ from vllm.utils import cuda_device_count_stateless
@ray.remote @ray.remote
class _CUDADeviceCountStatelessTestActor(): class _CUDADeviceCountStatelessTestActor:
def get_count(self): def get_count(self):
return cuda_device_count_stateless() return cuda_device_count_stateless()
...@@ -22,7 +22,8 @@ def test_cuda_device_count_stateless(): ...@@ -22,7 +22,8 @@ def test_cuda_device_count_stateless():
"""Test that cuda_device_count_stateless changes return value if """Test that cuda_device_count_stateless changes return value if
CUDA_VISIBLE_DEVICES is changed.""" CUDA_VISIBLE_DEVICES is changed."""
actor = _CUDADeviceCountStatelessTestActor.options(num_gpus=2).remote() actor = _CUDADeviceCountStatelessTestActor.options( # type: ignore
num_gpus=2).remote()
assert sorted(ray.get( assert sorted(ray.get(
actor.get_cuda_visible_devices.remote()).split(",")) == ["0", "1"] actor.get_cuda_visible_devices.remote()).split(",")) == ["0", "1"]
assert ray.get(actor.get_count.remote()) == 2 assert ray.get(actor.get_count.remote()) == 2
......
# imports for guided decoding tests # imports for guided decoding tests
import json import json
import re import re
from typing import List
import jsonschema import jsonschema
import openai # use the official client for correctness check import openai # use the official client for correctness check
...@@ -453,7 +454,7 @@ async def test_completion_streaming(client: openai.AsyncOpenAI, ...@@ -453,7 +454,7 @@ async def test_completion_streaming(client: openai.AsyncOpenAI,
max_tokens=5, max_tokens=5,
temperature=0.0, temperature=0.0,
stream=True) stream=True)
chunks = [] chunks: List[str] = []
finish_reason_count = 0 finish_reason_count = 0
async for chunk in stream: async for chunk in stream:
chunks.append(chunk.choices[0].text) chunks.append(chunk.choices[0].text)
...@@ -499,7 +500,7 @@ async def test_chat_streaming(client: openai.AsyncOpenAI, model_name: str): ...@@ -499,7 +500,7 @@ async def test_chat_streaming(client: openai.AsyncOpenAI, model_name: str):
temperature=0.0, temperature=0.0,
stream=True, stream=True,
) )
chunks = [] chunks: List[str] = []
finish_reason_count = 0 finish_reason_count = 0
async for chunk in stream: async for chunk in stream:
delta = chunk.choices[0].delta delta = chunk.choices[0].delta
......
...@@ -72,27 +72,27 @@ def ref_single_query_cached_kv_attention( ...@@ -72,27 +72,27 @@ def ref_single_query_cached_kv_attention(
block_size = value_cache.shape[3] block_size = value_cache.shape[3]
num_seqs = query.shape[0] num_seqs = query.shape[0]
block_tables = block_tables.cpu().tolist() block_tables_lst = block_tables.cpu().tolist()
seq_lens = seq_lens.cpu().tolist() seq_lens_lst = seq_lens.cpu().tolist()
for i in range(num_seqs): for i in range(num_seqs):
q = query[i].unsqueeze(0) q = query[i].unsqueeze(0)
block_table = block_tables[i] block_table = block_tables_lst[i]
seq_len = int(seq_lens[i]) seq_len = int(seq_lens_lst[i])
keys = [] keys_lst: List[torch.Tensor] = []
values = [] values_lst: List[torch.Tensor] = []
for j in range(seq_len): for j in range(seq_len):
block_number = int(block_table[j // block_size]) block_number = int(block_table[j // block_size])
block_offset = j % block_size block_offset = j % block_size
k = key_cache[block_number, :, :, block_offset, :] k = key_cache[block_number, :, :, block_offset, :]
k = k.reshape(num_kv_heads, head_size) k = k.reshape(num_kv_heads, head_size)
keys.append(k) keys_lst.append(k)
v = value_cache[block_number, :, :, block_offset] v = value_cache[block_number, :, :, block_offset]
values.append(v) values_lst.append(v)
keys = torch.stack(keys, dim=0) keys = torch.stack(keys_lst, dim=0)
values = torch.stack(values, dim=0) values = torch.stack(values_lst, dim=0)
if num_queries_per_kv > 1: if num_queries_per_kv > 1:
# Handle MQA and GQA # Handle MQA and GQA
keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1) keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1)
...@@ -157,14 +157,15 @@ def test_paged_attention( ...@@ -157,14 +157,15 @@ def test_paged_attention(
# Create the block tables. # Create the block tables.
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
block_tables = [] block_tables_lst: List[List[int]] = []
for _ in range(num_seqs): for _ in range(num_seqs):
block_table = [ block_table = [
random.randint(0, NUM_BLOCKS - 1) random.randint(0, NUM_BLOCKS - 1)
for _ in range(max_num_blocks_per_seq) for _ in range(max_num_blocks_per_seq)
] ]
block_tables.append(block_table) block_tables_lst.append(block_table)
block_tables = torch.tensor(block_tables, dtype=torch.int)
block_tables = torch.tensor(block_tables_lst, dtype=torch.int)
# Create the KV caches. # Create the KV caches.
key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1, key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
...@@ -283,7 +284,7 @@ def ref_multi_query_kv_attention( ...@@ -283,7 +284,7 @@ def ref_multi_query_kv_attention(
dtype: torch.dtype, dtype: torch.dtype,
) -> torch.Tensor: ) -> torch.Tensor:
num_seqs = len(cu_seq_lens) - 1 num_seqs = len(cu_seq_lens) - 1
ref_outputs = [] ref_outputs: List[torch.Tensor] = []
for i in range(num_seqs): for i in range(num_seqs):
start_idx = cu_seq_lens[i] start_idx = cu_seq_lens[i]
end_idx = cu_seq_lens[i + 1] end_idx = cu_seq_lens[i + 1]
...@@ -303,8 +304,8 @@ def ref_multi_query_kv_attention( ...@@ -303,8 +304,8 @@ def ref_multi_query_kv_attention(
attn_mask=attn_mask, attn_mask=attn_mask,
) )
ref_outputs.append(ref_output) ref_outputs.append(ref_output)
ref_output = torch.cat(ref_outputs, dim=0)
return ref_output return torch.cat(ref_outputs, dim=0)
# TODO(woosuk): Add tests for USE_ALIBI=True. # TODO(woosuk): Add tests for USE_ALIBI=True.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment