Unverified commit 823ab796, authored by Harry Mellor, committed by GitHub
Browse files

Update `pre-commit` hooks (#12475)


Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
parent 6116ca8c
......@@ -627,8 +627,8 @@ def attn_fwd(
causal_start_idx,
dtype=tl.int32)
mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M)
out_ptrs_mask = (mask_m_offsets[:, None] >=
out_mask_boundary[None, :])
out_ptrs_mask = (mask_m_offsets[:, None]
>= out_mask_boundary[None, :])
z = 0.0
acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty))
# write back LSE
......
import os
from contextlib import contextmanager
from functools import lru_cache
from functools import cache
from typing import Generator, Optional, Type
import torch
......@@ -100,7 +100,7 @@ def get_attn_backend(
)
@lru_cache(maxsize=None)
@cache
def _cached_get_attn_backend(
head_size: int,
dtype: torch.dtype,
......
......@@ -67,7 +67,8 @@ _RUNNER_TASKS: Dict[RunnerType, List[_ResolvedTask]] = {
_TASK_RUNNER: Dict[_ResolvedTask, RunnerType] = {
task: runner
for runner, tasks in _RUNNER_TASKS.items() for task in tasks
for runner, tasks in _RUNNER_TASKS.items()
for task in tasks
}
HfOverrides = Union[Dict[str, Any], Callable[[PretrainedConfig],
......@@ -1976,8 +1977,8 @@ class SpeculativeConfig:
"typical_acceptance_sampler.")
if (self.draft_token_acceptance_method != 'rejection_sampler'
and self.draft_token_acceptance_method !=
'typical_acceptance_sampler'):
and self.draft_token_acceptance_method
!= 'typical_acceptance_sampler'):
raise ValueError(
"Expected draft_token_acceptance_method to be either "
"rejection_sampler or typical_acceptance_sampler. Instead it "
......
......@@ -34,9 +34,10 @@ class RefCounter(RefCounterProtocol):
def __init__(self, all_block_indices: Iterable[BlockId]):
deduped = set(all_block_indices)
self._refcounts: Dict[BlockId,
RefCount] = {index: 0
for index in deduped}
self._refcounts: Dict[BlockId, RefCount] = {
index: 0
for index in deduped
}
def incr(self, block_id: BlockId) -> RefCount:
assert block_id in self._refcounts
......
......@@ -136,8 +136,8 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
device=Device.GPU)
# Use watermark to avoid frequent cache eviction.
if (self.num_total_gpu_blocks - num_required_blocks <
self.watermark_blocks):
if (self.num_total_gpu_blocks - num_required_blocks
< self.watermark_blocks):
return AllocStatus.NEVER
if num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks:
return AllocStatus.OK
......
......@@ -988,8 +988,8 @@ class Scheduler:
waiting_queue.popleft()
continue
if (budget.num_batched_tokens >=
self.scheduler_config.max_num_batched_tokens):
if (budget.num_batched_tokens
>= self.scheduler_config.max_num_batched_tokens):
# We've reached the budget limit - since there might be
# continuous prefills in the running queue, we should break
# to avoid scheduling any new prefills.
......@@ -1096,8 +1096,8 @@ class Scheduler:
running_scheduled.swapped_out) == 0:
swapped_in = self._schedule_swapped(budget, curr_loras)
assert (budget.num_batched_tokens <=
self.scheduler_config.max_num_batched_tokens)
assert (budget.num_batched_tokens
<= self.scheduler_config.max_num_batched_tokens)
assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs
# Update waiting requests.
......@@ -1189,8 +1189,8 @@ class Scheduler:
curr_loras,
enable_chunking=True)
assert (budget.num_batched_tokens <=
self.scheduler_config.max_num_batched_tokens)
assert (budget.num_batched_tokens
<= self.scheduler_config.max_num_batched_tokens)
assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs
# Update waiting requests.
......@@ -1358,8 +1358,8 @@ class Scheduler:
# NOTE: We use get_len instead of get_prompt_len because when
# a sequence is preempted, prefill includes previous generated
# output tokens.
if (token_chunk_size + num_computed_tokens <
seqs[0].data.get_len()):
if (token_chunk_size + num_computed_tokens
< seqs[0].data.get_len()):
do_sample = False
# It assumes the scheduled_seq_groups is ordered by
......@@ -1625,10 +1625,9 @@ class Scheduler:
if self.scheduler_config.delay_factor > 0 and self.waiting:
earliest_arrival_time = min(
[e.metrics.arrival_time for e in self.waiting])
passed_delay = (
(now - earliest_arrival_time) >
(self.scheduler_config.delay_factor * self.last_prompt_latency)
or not self.running)
passed_delay = ((now - earliest_arrival_time)
> (self.scheduler_config.delay_factor *
self.last_prompt_latency) or not self.running)
else:
passed_delay = True
return passed_delay
......
......@@ -352,8 +352,8 @@ class MessageQueue:
sched_yield()
# if we wait for a long time, log a message
if (time.monotonic() - start_time >
VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning):
if (time.monotonic() - start_time
> VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning):
logger.debug("No available block found in %s second. ",
VLLM_RINGBUFFER_WARNING_INTERVAL)
n_warning += 1
......@@ -410,8 +410,8 @@ class MessageQueue:
sched_yield()
# if we wait for a long time, log a message
if (time.monotonic() - start_time >
VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning):
if (time.monotonic() - start_time
> VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning):
logger.debug("No available block found in %s second. ",
VLLM_RINGBUFFER_WARNING_INTERVAL)
n_warning += 1
......
......@@ -1014,8 +1014,8 @@ def initialize_model_parallel(
backend = backend or torch.distributed.get_backend(
get_world_group().device_group)
if (world_size !=
tensor_model_parallel_size * pipeline_model_parallel_size):
if (world_size
!= tensor_model_parallel_size * pipeline_model_parallel_size):
raise RuntimeError(
f"world_size ({world_size}) is not equal to "
f"tensor_model_parallel_size ({tensor_model_parallel_size}) x "
......@@ -1069,8 +1069,8 @@ def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None:
return
if all([
vllm_config.kv_transfer_config.need_kv_parallel_group,
_KV_TRANSFER is None
vllm_config.kv_transfer_config.need_kv_parallel_group, _KV_TRANSFER
is None
]):
_KV_TRANSFER = kv_transfer.KVTransferAgent(
rank=get_world_group().rank,
......
......@@ -3,7 +3,7 @@ import codecs
import json
from abc import ABC, abstractmethod
from collections import defaultdict, deque
from functools import lru_cache, partial
from functools import cache, lru_cache, partial
from pathlib import Path
from typing import (Any, Awaitable, Callable, Dict, Generic, Iterable, List,
Literal, Optional, Tuple, TypeVar, Union, cast)
......@@ -377,7 +377,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
return self._model_config.allowed_local_media_path
@staticmethod
@lru_cache(maxsize=None)
@cache
def _cached_token_str(tokenizer: AnyTokenizer, token_index: int) -> str:
return tokenizer.decode(token_index)
......
......@@ -522,11 +522,10 @@ class OpenAIServingCompletion(OpenAIServing):
out_top_logprobs.append({
# Convert float("-inf") to the
# JSON-serializable float that OpenAI uses
self._get_decoded_token(
top_lp[1],
top_lp[0],
tokenizer,
return_as_token_id=self.return_tokens_as_token_ids):
self._get_decoded_token(top_lp[1],
top_lp[0],
tokenizer,
return_as_token_id=self.return_tokens_as_token_ids):
max(top_lp[1].logprob, -9999.0)
for i, top_lp in enumerate(step_top_logprobs.items())
if num_output_top_logprobs >= i
......
......@@ -62,8 +62,8 @@ class Granite20bFCToolParser(ToolParser):
start_of_json = match.end()
# end_index == the start of the next function call
# (if exists)
next_function_call_start = (matches[i + 1].start()
if i + 1 < len(matches) else None)
next_function_call_start = (matches[i + 1].start() if i +
1 < len(matches) else None)
raw_function_calls.append(
dec.raw_decode(
......
......@@ -220,8 +220,10 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
lora_b.T, non_blocking=True)
if embeddings_tensor is not None:
self.embeddings_tensors[
index, :embeddings_tensor.shape[0], :embeddings_tensor.
shape[1], ].copy_(embeddings_tensor, non_blocking=True)
index,
:embeddings_tensor.shape[0],
:embeddings_tensor.shape[1],
].copy_(embeddings_tensor, non_blocking=True)
if self.embeddings_slice is not None:
# TODO(yard1): Optimize this copy, we don't need to copy
# everything, just the modified part
......@@ -1024,8 +1026,10 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
lora_b.T, non_blocking=True)
if embeddings_tensor is not None:
self.embeddings_tensors[
index, :embeddings_tensor.shape[0], :embeddings_tensor.
shape[1], ] = embeddings_tensor
index,
:embeddings_tensor.shape[0],
:embeddings_tensor.shape[1],
] = embeddings_tensor
def _get_logits(
self,
......
......@@ -75,8 +75,9 @@ class LoRAModel(AdapterModel):
# Scaling factor for long context lora model. None if it is not
# fine tuned for the long context.
self.scaling_factor = scaling_factor
assert (lora_model_id >
0), f"a valid lora id should be greater than 0, got {self.id}"
assert (
lora_model_id
> 0), f"a valid lora id should be greater than 0, got {self.id}"
self.rank = rank
self.loras: Dict[str, LoRALayerWeights] = loras
......
......@@ -136,9 +136,8 @@ def _sgmv_expand_kernel(
c_ptr = (out_ptr + offset_cm[:, None] * output_d0_stride +
offset_cn[None, :] * output_d1_stride)
M = tl.load(seq_lens + cur_batch)
c_mask = (offset_cm[:, None] <
(cur_seq_start + M)) & (offset_cn[None, :] <
(cur_slice_start + curr_N))
c_mask = (offset_cm[:, None] < (cur_seq_start + M)) & (
offset_cn[None, :] < (cur_slice_start + curr_N))
if ADD_INPUTS:
tiled_out = tl.load(c_ptr, mask=c_mask)
tiled_c += tiled_out
......
......@@ -114,8 +114,8 @@ def _sgmv_shrink_kernel(
slice_id * output_d0_stride)
c_ptr = cur_out_ptr + offset_cm[:, None] * output_d1_stride + offset_cn[
None, :] * output_d2_stride
c_mask = (offset_cm[:, None] <
(cur_seq_start + M)) & (offset_cn[None, :] < N)
c_mask = (offset_cm[:, None] < (cur_seq_start + M)) & (offset_cn[None, :]
< N)
accumulator *= scaling
# handles write-back with reduction-splitting
if SPLIT_K == 1:
......
......@@ -73,12 +73,12 @@ class MPLinearKernel(ABC):
torch.nn.Parameter(new_param.data, requires_grad=False))
def _get_weight_params(
self, layer: torch.nn.Module
) -> Tuple[torch.Tensor, # w_q
torch.Tensor, # w_s
Optional[torch.Tensor], # w_zp,
Optional[torch.Tensor] # w_gidx
]:
self, layer: torch.nn.Module) -> Tuple[
torch.Tensor, # w_q
torch.Tensor, # w_s
Optional[torch.Tensor], # w_zp,
Optional[torch.Tensor] # w_gidx
]:
return (
getattr(layer, self.w_q_name),
getattr(layer, self.w_s_name),
......
......@@ -48,13 +48,13 @@ class ScaledMMLinearKernel(ABC):
raise NotImplementedError
def _get_weight_params(
self, layer: torch.nn.Module
) -> Tuple[torch.Tensor, # weight
torch.Tensor, # weight_scale
Optional[torch.Tensor], # input_scale,
Optional[torch.Tensor], # input_zp
Optional[torch.Tensor], # azp_adj
]:
self, layer: torch.nn.Module) -> Tuple[
torch.Tensor, # weight
torch.Tensor, # weight_scale
Optional[torch.Tensor], # input_scale,
Optional[torch.Tensor], # input_zp
Optional[torch.Tensor], # azp_adj
]:
return (
getattr(layer, self.w_q_name),
getattr(layer, self.w_s_name),
......
......@@ -72,9 +72,10 @@ def block_quant_to_tensor_quant(
x_dq_block = x_q_block.to(torch.float32)
x_dq_block_tiles = [[
x_dq_block[j * block_n:min((j + 1) * block_n, n),
i * block_k:min((i + 1) * block_k, k), ]
for i in range(k_tiles)
x_dq_block[
j * block_n:min((j + 1) * block_n, n),
i * block_k:min((i + 1) * block_k, k),
] for i in range(k_tiles)
] for j in range(n_tiles)]
for i in range(k_tiles):
......
......@@ -73,8 +73,8 @@ def requantize_with_max_scale(
# from disk in this case. Skip requantization in this case (since)
# we already are quantized with the single scale.
# * Sample Model: nm-testing/Phi-3-mini-128k-instruct-FP8
unfused_module_in_checkpoint = (weight_scale[-1] > torch.finfo(
torch.float8_e4m3fn).min)
unfused_module_in_checkpoint = (weight_scale[-1]
> torch.finfo(torch.float8_e4m3fn).min)
# If unfused checkpoint, need to requantize with the single scale.
if unfused_module_in_checkpoint:
......
......@@ -716,9 +716,10 @@ def _sample_with_torch(
tensors required for Pythonization
'''
categorized_seq_group_ids: Dict[SamplingType,
List[int]] = {t: []
for t in SamplingType}
categorized_seq_group_ids: Dict[SamplingType, List[int]] = {
t: []
for t in SamplingType
}
categorized_sample_indices = sampling_metadata.categorized_sample_indices
for i, seq_group in enumerate(sampling_metadata.seq_groups):
sampling_params = seq_group.sampling_params
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment