Unverified Commit 823ab796 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Update `pre-commit` hooks (#12475)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 6116ca8c
...@@ -627,8 +627,8 @@ def attn_fwd( ...@@ -627,8 +627,8 @@ def attn_fwd(
causal_start_idx, causal_start_idx,
dtype=tl.int32) dtype=tl.int32)
mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M) mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M)
out_ptrs_mask = (mask_m_offsets[:, None] >= out_ptrs_mask = (mask_m_offsets[:, None]
out_mask_boundary[None, :]) >= out_mask_boundary[None, :])
z = 0.0 z = 0.0
acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty)) acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty))
# write back LSE # write back LSE
......
import os import os
from contextlib import contextmanager from contextlib import contextmanager
from functools import lru_cache from functools import cache
from typing import Generator, Optional, Type from typing import Generator, Optional, Type
import torch import torch
...@@ -100,7 +100,7 @@ def get_attn_backend( ...@@ -100,7 +100,7 @@ def get_attn_backend(
) )
@lru_cache(maxsize=None) @cache
def _cached_get_attn_backend( def _cached_get_attn_backend(
head_size: int, head_size: int,
dtype: torch.dtype, dtype: torch.dtype,
......
...@@ -67,7 +67,8 @@ _RUNNER_TASKS: Dict[RunnerType, List[_ResolvedTask]] = { ...@@ -67,7 +67,8 @@ _RUNNER_TASKS: Dict[RunnerType, List[_ResolvedTask]] = {
_TASK_RUNNER: Dict[_ResolvedTask, RunnerType] = { _TASK_RUNNER: Dict[_ResolvedTask, RunnerType] = {
task: runner task: runner
for runner, tasks in _RUNNER_TASKS.items() for task in tasks for runner, tasks in _RUNNER_TASKS.items()
for task in tasks
} }
HfOverrides = Union[Dict[str, Any], Callable[[PretrainedConfig], HfOverrides = Union[Dict[str, Any], Callable[[PretrainedConfig],
...@@ -1976,8 +1977,8 @@ class SpeculativeConfig: ...@@ -1976,8 +1977,8 @@ class SpeculativeConfig:
"typical_acceptance_sampler.") "typical_acceptance_sampler.")
if (self.draft_token_acceptance_method != 'rejection_sampler' if (self.draft_token_acceptance_method != 'rejection_sampler'
and self.draft_token_acceptance_method != and self.draft_token_acceptance_method
'typical_acceptance_sampler'): != 'typical_acceptance_sampler'):
raise ValueError( raise ValueError(
"Expected draft_token_acceptance_method to be either " "Expected draft_token_acceptance_method to be either "
"rejection_sampler or typical_acceptance_sampler. Instead it " "rejection_sampler or typical_acceptance_sampler. Instead it "
......
...@@ -34,9 +34,10 @@ class RefCounter(RefCounterProtocol): ...@@ -34,9 +34,10 @@ class RefCounter(RefCounterProtocol):
def __init__(self, all_block_indices: Iterable[BlockId]): def __init__(self, all_block_indices: Iterable[BlockId]):
deduped = set(all_block_indices) deduped = set(all_block_indices)
self._refcounts: Dict[BlockId, self._refcounts: Dict[BlockId, RefCount] = {
RefCount] = {index: 0 index: 0
for index in deduped} for index in deduped
}
def incr(self, block_id: BlockId) -> RefCount: def incr(self, block_id: BlockId) -> RefCount:
assert block_id in self._refcounts assert block_id in self._refcounts
......
...@@ -136,8 +136,8 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager): ...@@ -136,8 +136,8 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
device=Device.GPU) device=Device.GPU)
# Use watermark to avoid frequent cache eviction. # Use watermark to avoid frequent cache eviction.
if (self.num_total_gpu_blocks - num_required_blocks < if (self.num_total_gpu_blocks - num_required_blocks
self.watermark_blocks): < self.watermark_blocks):
return AllocStatus.NEVER return AllocStatus.NEVER
if num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks: if num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks:
return AllocStatus.OK return AllocStatus.OK
......
...@@ -988,8 +988,8 @@ class Scheduler: ...@@ -988,8 +988,8 @@ class Scheduler:
waiting_queue.popleft() waiting_queue.popleft()
continue continue
if (budget.num_batched_tokens >= if (budget.num_batched_tokens
self.scheduler_config.max_num_batched_tokens): >= self.scheduler_config.max_num_batched_tokens):
# We've reached the budget limit - since there might be # We've reached the budget limit - since there might be
# continuous prefills in the running queue, we should break # continuous prefills in the running queue, we should break
# to avoid scheduling any new prefills. # to avoid scheduling any new prefills.
...@@ -1096,8 +1096,8 @@ class Scheduler: ...@@ -1096,8 +1096,8 @@ class Scheduler:
running_scheduled.swapped_out) == 0: running_scheduled.swapped_out) == 0:
swapped_in = self._schedule_swapped(budget, curr_loras) swapped_in = self._schedule_swapped(budget, curr_loras)
assert (budget.num_batched_tokens <= assert (budget.num_batched_tokens
self.scheduler_config.max_num_batched_tokens) <= self.scheduler_config.max_num_batched_tokens)
assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs
# Update waiting requests. # Update waiting requests.
...@@ -1189,8 +1189,8 @@ class Scheduler: ...@@ -1189,8 +1189,8 @@ class Scheduler:
curr_loras, curr_loras,
enable_chunking=True) enable_chunking=True)
assert (budget.num_batched_tokens <= assert (budget.num_batched_tokens
self.scheduler_config.max_num_batched_tokens) <= self.scheduler_config.max_num_batched_tokens)
assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs
# Update waiting requests. # Update waiting requests.
...@@ -1358,8 +1358,8 @@ class Scheduler: ...@@ -1358,8 +1358,8 @@ class Scheduler:
# NOTE: We use get_len instead of get_prompt_len because when # NOTE: We use get_len instead of get_prompt_len because when
# a sequence is preempted, prefill includes previous generated # a sequence is preempted, prefill includes previous generated
# output tokens. # output tokens.
if (token_chunk_size + num_computed_tokens < if (token_chunk_size + num_computed_tokens
seqs[0].data.get_len()): < seqs[0].data.get_len()):
do_sample = False do_sample = False
# It assumes the scheduled_seq_groups is ordered by # It assumes the scheduled_seq_groups is ordered by
...@@ -1625,10 +1625,9 @@ class Scheduler: ...@@ -1625,10 +1625,9 @@ class Scheduler:
if self.scheduler_config.delay_factor > 0 and self.waiting: if self.scheduler_config.delay_factor > 0 and self.waiting:
earliest_arrival_time = min( earliest_arrival_time = min(
[e.metrics.arrival_time for e in self.waiting]) [e.metrics.arrival_time for e in self.waiting])
passed_delay = ( passed_delay = ((now - earliest_arrival_time)
(now - earliest_arrival_time) > > (self.scheduler_config.delay_factor *
(self.scheduler_config.delay_factor * self.last_prompt_latency) self.last_prompt_latency) or not self.running)
or not self.running)
else: else:
passed_delay = True passed_delay = True
return passed_delay return passed_delay
......
...@@ -352,8 +352,8 @@ class MessageQueue: ...@@ -352,8 +352,8 @@ class MessageQueue:
sched_yield() sched_yield()
# if we wait for a long time, log a message # if we wait for a long time, log a message
if (time.monotonic() - start_time > if (time.monotonic() - start_time
VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning): > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning):
logger.debug("No available block found in %s second. ", logger.debug("No available block found in %s second. ",
VLLM_RINGBUFFER_WARNING_INTERVAL) VLLM_RINGBUFFER_WARNING_INTERVAL)
n_warning += 1 n_warning += 1
...@@ -410,8 +410,8 @@ class MessageQueue: ...@@ -410,8 +410,8 @@ class MessageQueue:
sched_yield() sched_yield()
# if we wait for a long time, log a message # if we wait for a long time, log a message
if (time.monotonic() - start_time > if (time.monotonic() - start_time
VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning): > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning):
logger.debug("No available block found in %s second. ", logger.debug("No available block found in %s second. ",
VLLM_RINGBUFFER_WARNING_INTERVAL) VLLM_RINGBUFFER_WARNING_INTERVAL)
n_warning += 1 n_warning += 1
......
...@@ -1014,8 +1014,8 @@ def initialize_model_parallel( ...@@ -1014,8 +1014,8 @@ def initialize_model_parallel(
backend = backend or torch.distributed.get_backend( backend = backend or torch.distributed.get_backend(
get_world_group().device_group) get_world_group().device_group)
if (world_size != if (world_size
tensor_model_parallel_size * pipeline_model_parallel_size): != tensor_model_parallel_size * pipeline_model_parallel_size):
raise RuntimeError( raise RuntimeError(
f"world_size ({world_size}) is not equal to " f"world_size ({world_size}) is not equal to "
f"tensor_model_parallel_size ({tensor_model_parallel_size}) x " f"tensor_model_parallel_size ({tensor_model_parallel_size}) x "
...@@ -1069,8 +1069,8 @@ def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None: ...@@ -1069,8 +1069,8 @@ def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None:
return return
if all([ if all([
vllm_config.kv_transfer_config.need_kv_parallel_group, vllm_config.kv_transfer_config.need_kv_parallel_group, _KV_TRANSFER
_KV_TRANSFER is None is None
]): ]):
_KV_TRANSFER = kv_transfer.KVTransferAgent( _KV_TRANSFER = kv_transfer.KVTransferAgent(
rank=get_world_group().rank, rank=get_world_group().rank,
......
...@@ -3,7 +3,7 @@ import codecs ...@@ -3,7 +3,7 @@ import codecs
import json import json
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections import defaultdict, deque from collections import defaultdict, deque
from functools import lru_cache, partial from functools import cache, lru_cache, partial
from pathlib import Path from pathlib import Path
from typing import (Any, Awaitable, Callable, Dict, Generic, Iterable, List, from typing import (Any, Awaitable, Callable, Dict, Generic, Iterable, List,
Literal, Optional, Tuple, TypeVar, Union, cast) Literal, Optional, Tuple, TypeVar, Union, cast)
...@@ -377,7 +377,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): ...@@ -377,7 +377,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
return self._model_config.allowed_local_media_path return self._model_config.allowed_local_media_path
@staticmethod @staticmethod
@lru_cache(maxsize=None) @cache
def _cached_token_str(tokenizer: AnyTokenizer, token_index: int) -> str: def _cached_token_str(tokenizer: AnyTokenizer, token_index: int) -> str:
return tokenizer.decode(token_index) return tokenizer.decode(token_index)
......
...@@ -522,8 +522,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -522,8 +522,7 @@ class OpenAIServingCompletion(OpenAIServing):
out_top_logprobs.append({ out_top_logprobs.append({
# Convert float("-inf") to the # Convert float("-inf") to the
# JSON-serializable float that OpenAI uses # JSON-serializable float that OpenAI uses
self._get_decoded_token( self._get_decoded_token(top_lp[1],
top_lp[1],
top_lp[0], top_lp[0],
tokenizer, tokenizer,
return_as_token_id=self.return_tokens_as_token_ids): return_as_token_id=self.return_tokens_as_token_ids):
......
...@@ -62,8 +62,8 @@ class Granite20bFCToolParser(ToolParser): ...@@ -62,8 +62,8 @@ class Granite20bFCToolParser(ToolParser):
start_of_json = match.end() start_of_json = match.end()
# end_index == the start of the next function call # end_index == the start of the next function call
# (if exists) # (if exists)
next_function_call_start = (matches[i + 1].start() next_function_call_start = (matches[i + 1].start() if i +
if i + 1 < len(matches) else None) 1 < len(matches) else None)
raw_function_calls.append( raw_function_calls.append(
dec.raw_decode( dec.raw_decode(
......
...@@ -220,8 +220,10 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): ...@@ -220,8 +220,10 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
lora_b.T, non_blocking=True) lora_b.T, non_blocking=True)
if embeddings_tensor is not None: if embeddings_tensor is not None:
self.embeddings_tensors[ self.embeddings_tensors[
index, :embeddings_tensor.shape[0], :embeddings_tensor. index,
shape[1], ].copy_(embeddings_tensor, non_blocking=True) :embeddings_tensor.shape[0],
:embeddings_tensor.shape[1],
].copy_(embeddings_tensor, non_blocking=True)
if self.embeddings_slice is not None: if self.embeddings_slice is not None:
# TODO(yard1): Optimize this copy, we don't need to copy # TODO(yard1): Optimize this copy, we don't need to copy
# everything, just the modified part # everything, just the modified part
...@@ -1024,8 +1026,10 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA): ...@@ -1024,8 +1026,10 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
lora_b.T, non_blocking=True) lora_b.T, non_blocking=True)
if embeddings_tensor is not None: if embeddings_tensor is not None:
self.embeddings_tensors[ self.embeddings_tensors[
index, :embeddings_tensor.shape[0], :embeddings_tensor. index,
shape[1], ] = embeddings_tensor :embeddings_tensor.shape[0],
:embeddings_tensor.shape[1],
] = embeddings_tensor
def _get_logits( def _get_logits(
self, self,
......
...@@ -75,8 +75,9 @@ class LoRAModel(AdapterModel): ...@@ -75,8 +75,9 @@ class LoRAModel(AdapterModel):
# Scaling factor for long context lora model. None if it is not # Scaling factor for long context lora model. None if it is not
# fine tuned for the long context. # fine tuned for the long context.
self.scaling_factor = scaling_factor self.scaling_factor = scaling_factor
assert (lora_model_id > assert (
0), f"a valid lora id should be greater than 0, got {self.id}" lora_model_id
> 0), f"a valid lora id should be greater than 0, got {self.id}"
self.rank = rank self.rank = rank
self.loras: Dict[str, LoRALayerWeights] = loras self.loras: Dict[str, LoRALayerWeights] = loras
......
...@@ -136,9 +136,8 @@ def _sgmv_expand_kernel( ...@@ -136,9 +136,8 @@ def _sgmv_expand_kernel(
c_ptr = (out_ptr + offset_cm[:, None] * output_d0_stride + c_ptr = (out_ptr + offset_cm[:, None] * output_d0_stride +
offset_cn[None, :] * output_d1_stride) offset_cn[None, :] * output_d1_stride)
M = tl.load(seq_lens + cur_batch) M = tl.load(seq_lens + cur_batch)
c_mask = (offset_cm[:, None] < c_mask = (offset_cm[:, None] < (cur_seq_start + M)) & (
(cur_seq_start + M)) & (offset_cn[None, :] < offset_cn[None, :] < (cur_slice_start + curr_N))
(cur_slice_start + curr_N))
if ADD_INPUTS: if ADD_INPUTS:
tiled_out = tl.load(c_ptr, mask=c_mask) tiled_out = tl.load(c_ptr, mask=c_mask)
tiled_c += tiled_out tiled_c += tiled_out
......
...@@ -114,8 +114,8 @@ def _sgmv_shrink_kernel( ...@@ -114,8 +114,8 @@ def _sgmv_shrink_kernel(
slice_id * output_d0_stride) slice_id * output_d0_stride)
c_ptr = cur_out_ptr + offset_cm[:, None] * output_d1_stride + offset_cn[ c_ptr = cur_out_ptr + offset_cm[:, None] * output_d1_stride + offset_cn[
None, :] * output_d2_stride None, :] * output_d2_stride
c_mask = (offset_cm[:, None] < c_mask = (offset_cm[:, None] < (cur_seq_start + M)) & (offset_cn[None, :]
(cur_seq_start + M)) & (offset_cn[None, :] < N) < N)
accumulator *= scaling accumulator *= scaling
# handles write-back with reduction-splitting # handles write-back with reduction-splitting
if SPLIT_K == 1: if SPLIT_K == 1:
......
...@@ -73,8 +73,8 @@ class MPLinearKernel(ABC): ...@@ -73,8 +73,8 @@ class MPLinearKernel(ABC):
torch.nn.Parameter(new_param.data, requires_grad=False)) torch.nn.Parameter(new_param.data, requires_grad=False))
def _get_weight_params( def _get_weight_params(
self, layer: torch.nn.Module self, layer: torch.nn.Module) -> Tuple[
) -> Tuple[torch.Tensor, # w_q torch.Tensor, # w_q
torch.Tensor, # w_s torch.Tensor, # w_s
Optional[torch.Tensor], # w_zp, Optional[torch.Tensor], # w_zp,
Optional[torch.Tensor] # w_gidx Optional[torch.Tensor] # w_gidx
......
...@@ -48,8 +48,8 @@ class ScaledMMLinearKernel(ABC): ...@@ -48,8 +48,8 @@ class ScaledMMLinearKernel(ABC):
raise NotImplementedError raise NotImplementedError
def _get_weight_params( def _get_weight_params(
self, layer: torch.nn.Module self, layer: torch.nn.Module) -> Tuple[
) -> Tuple[torch.Tensor, # weight torch.Tensor, # weight
torch.Tensor, # weight_scale torch.Tensor, # weight_scale
Optional[torch.Tensor], # input_scale, Optional[torch.Tensor], # input_scale,
Optional[torch.Tensor], # input_zp Optional[torch.Tensor], # input_zp
......
...@@ -72,9 +72,10 @@ def block_quant_to_tensor_quant( ...@@ -72,9 +72,10 @@ def block_quant_to_tensor_quant(
x_dq_block = x_q_block.to(torch.float32) x_dq_block = x_q_block.to(torch.float32)
x_dq_block_tiles = [[ x_dq_block_tiles = [[
x_dq_block[j * block_n:min((j + 1) * block_n, n), x_dq_block[
i * block_k:min((i + 1) * block_k, k), ] j * block_n:min((j + 1) * block_n, n),
for i in range(k_tiles) i * block_k:min((i + 1) * block_k, k),
] for i in range(k_tiles)
] for j in range(n_tiles)] ] for j in range(n_tiles)]
for i in range(k_tiles): for i in range(k_tiles):
......
...@@ -73,8 +73,8 @@ def requantize_with_max_scale( ...@@ -73,8 +73,8 @@ def requantize_with_max_scale(
# from disk in this case. Skip requantization in this case (since) # from disk in this case. Skip requantization in this case (since)
# we already are quantized with the single scale. # we already are quantized with the single scale.
# * Sample Model: nm-testing/Phi-3-mini-128k-instruct-FP8 # * Sample Model: nm-testing/Phi-3-mini-128k-instruct-FP8
unfused_module_in_checkpoint = (weight_scale[-1] > torch.finfo( unfused_module_in_checkpoint = (weight_scale[-1]
torch.float8_e4m3fn).min) > torch.finfo(torch.float8_e4m3fn).min)
# If unfused checkpoint, need requanize with the single scale. # If unfused checkpoint, need requanize with the single scale.
if unfused_module_in_checkpoint: if unfused_module_in_checkpoint:
......
...@@ -716,9 +716,10 @@ def _sample_with_torch( ...@@ -716,9 +716,10 @@ def _sample_with_torch(
tensors required for Pythonization tensors required for Pythonization
''' '''
categorized_seq_group_ids: Dict[SamplingType, categorized_seq_group_ids: Dict[SamplingType, List[int]] = {
List[int]] = {t: [] t: []
for t in SamplingType} for t in SamplingType
}
categorized_sample_indices = sampling_metadata.categorized_sample_indices categorized_sample_indices = sampling_metadata.categorized_sample_indices
for i, seq_group in enumerate(sampling_metadata.seq_groups): for i, seq_group in enumerate(sampling_metadata.seq_groups):
sampling_params = seq_group.sampling_params sampling_params = seq_group.sampling_params
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment