Unverified Commit b893d661 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Fix per file ruff ignores related to simplification (#26259)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 6b6e9877
......@@ -101,40 +101,6 @@ include = ["vllm*"]
"vllm/v1/engine/utils.py" = ["E501"]
"vllm/v1/utils.py" = ["E501"]
"vllm/v1/worker/gpu_model_runner.py" = ["E501"]
## Simplification rules
"tests/distributed/test_expert_placement.py" = ["SIM108"]
"tests/kernels/attention/test_cutlass_mla_decode.py" = ["SIM108"]
"tests/kernels/attention/test_flashmla.py" = ["SIM108"]
"tests/kernels/attention/test_lightning_attn.py" = ["SIM108"]
"tests/kernels/moe/test_pplx_moe.py" = ["SIM108"]
"tests/kernels/quantization/test_cutlass_scaled_mm.py" = ["SIM108"]
"tests/kernels/test_onednn.py" = ["SIM108"]
"tests/kernels/utils.py" = ["SIM108"]
"tests/multimodal/test_processing.py" = ["SIM108"]
"vllm/attention/ops/triton_reshape_and_cache_flash.py" = ["SIM108"]
"vllm/distributed/parallel_state.py" = ["SIM108"]
"vllm/entrypoints/chat_utils.py" = ["SIM108"]
"vllm/entrypoints/llm.py" = ["SIM108"]
"vllm/executor/ray_distributed_executor.py" = ["SIM108", "SIM112"]
"vllm/model_executor/layers/batch_invariant.py" = ["SIM108"]
"vllm/model_executor/layers/fla/ops/chunk_o.py" = ["SIM108"]
"vllm/model_executor/layers/fused_moe/fused_moe.py" = ["SIM108"]
"vllm/model_executor/layers/fused_moe/layer.py" = ["SIM108"]
"vllm/model_executor/layers/fused_moe/modular_kernel.py" = ["SIM108"]
"vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py" = ["SIM108"]
"vllm/model_executor/layers/layernorm.py" = ["SIM108"]
"vllm/model_executor/layers/lightning_attn.py" = ["SIM108"]
"vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py" = ["SIM103"]
"vllm/model_executor/layers/quantization/compressed_tensors/utils.py" = ["SIM110"]
"vllm/model_executor/layers/quantization/quark/utils.py" = ["SIM110"]
"vllm/utils/__init__.py" = ["SIM108"]
"vllm/v1/sample/ops/bad_words.py" = ["SIM108"]
"vllm/v1/sample/rejection_sampler.py" = ["SIM108"]
"vllm/v1/worker/tpu_model_runner.py" = ["SIM108"]
"vllm/_custom_ops.py" = ["SIM108"]
"tools/profiler/print_layerwise_table.py" = ["SIM118"]
## Loop variable binding issues
"tests/kernels/mamba/test_mamba_ssm_ssd.py" = ["B023"]
# End of temporary ignores
[tool.ruff.lint]
......
......@@ -12,10 +12,7 @@ def verify_round_robin_pattern(expert_map, ep_rank, ep_size, global_num_experts)
base_experts = global_num_experts // ep_size
remainder = global_num_experts % ep_size
if ep_rank < remainder:
local_num_experts = base_experts + 1
else:
local_num_experts = base_experts
local_num_experts = base_experts + 1 if ep_rank < remainder else base_experts
# Expected expert IDs for this rank in round_robin pattern
# For non-divisible cases, ranks with extra experts start earlier
......
......@@ -66,10 +66,7 @@ def test_cutlass_mla_decode(
b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal, varlen, torch_dtype
):
device = torch.device("cuda:0")
if torch_dtype == torch.float8_e4m3fn:
init_dtype = torch.bfloat16
else:
init_dtype = torch_dtype
init_dtype = torch.bfloat16 if torch_dtype == torch.float8_e4m3fn else torch_dtype
torch.set_default_dtype(init_dtype)
torch.set_default_device(device)
torch.cuda.set_device(device)
......
......@@ -52,10 +52,7 @@ def test_flash_mla(
b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal, varlen, torch_dtype
):
device = torch.device("cuda:0")
if torch_dtype == torch.float8_e4m3fn:
init_dtype = torch.bfloat16
else:
init_dtype = torch_dtype
init_dtype = torch.bfloat16 if torch_dtype == torch.float8_e4m3fn else torch_dtype
torch.set_default_dtype(init_dtype)
torch.set_default_device(device)
torch.cuda.set_device(device)
......
......@@ -33,10 +33,7 @@ def reference_lightning_attention(q, k, v, ed, block_size, kv_history):
# More efficient implementation
# Convert decay factors to matrix form
if ed.dim() == 1:
decay = torch.exp(-ed).view(1, -1, 1, 1)
else:
decay = torch.exp(-ed)
decay = torch.exp(-ed).view(1, -1, 1, 1) if ed.dim() == 1 else torch.exp(-ed)
for b in range(B):
for step in range(S):
......
......@@ -705,10 +705,7 @@ def _pplx_moe(
with set_current_vllm_config(vllm_config), override_config(moe_config):
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
if shared_experts is not None:
shared_output = shared_experts(a)
else:
shared_output = None
shared_output = shared_experts(a) if shared_experts is not None else None
torch_output = torch_experts(
a,
......
......@@ -88,10 +88,7 @@ def cutlass_fp8_gemm_helper(
# make scales K-major for blockwise quant, doesn't affect 1D scales
scale_b = scale_b.t().contiguous().t()
if use_bias:
bias = torch.rand((n,), device=device, dtype=out_dtype) * 10
else:
bias = None
bias = torch.rand((n,), device=device, dtype=out_dtype) * 10 if use_bias else None
out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
......@@ -122,10 +119,7 @@ def cutlass_int8_gemm_helper(
scale_a = torch.randn(a_scales_shape, device=device, dtype=torch.float32)
scale_b = torch.randn(b_scales_shape, device=device, dtype=torch.float32)
if use_bias:
bias = torch.rand((n,), device=device, dtype=out_dtype) * 10
else:
bias = None
bias = torch.rand((n,), device=device, dtype=out_dtype) * 10 if use_bias else None
out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
......
......@@ -84,10 +84,7 @@ def onednn_int8_gemm_test_helper(
azp = None
azp_adj = None
if use_bias:
bias = torch.rand((n,), device=device, dtype=out_dtype) * 10
else:
bias = None
bias = torch.rand((n,), device=device, dtype=out_dtype) * 10 if use_bias else None
handler = ops.create_onednn_scaled_mm(
b,
......
......@@ -963,13 +963,9 @@ def make_test_metadata(
None if encoder_seq_lens is None else (sum(encoder_seq_lens))
)
if cross_test_params is None:
cross_kv_mmap = None
else:
# Encoder/decoder or encoder-only models only:
# * Extract *cross-attention* slot_mapping and block table
# (kv_mmap)
cross_kv_mmap = cross_test_params.kv_mmap
# For encoder/decoder or encoder-only models only, extract *cross-attention*
# slot_mapping and block table (kv_mmap)
cross_kv_mmap = None if cross_test_params is None else cross_test_params.kv_mmap
attn_backend_obj = make_backend(attn_backend.name)
......
......@@ -941,10 +941,7 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid):
profiler = MultiModalProfiler(processor)
if is_valid:
exc_ctx = nullcontext()
else:
exc_ctx = pytest.raises(ValueError, match="At most")
exc_ctx = nullcontext() if is_valid else pytest.raises(ValueError, match="At most")
with exc_ctx:
profiler.get_decoder_dummy_data(
......@@ -985,10 +982,7 @@ def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid):
else:
mm_data = {"image": [image] * num_images}
if is_valid:
exc_ctx = nullcontext()
else:
exc_ctx = pytest.raises(ValueError, match="At most")
exc_ctx = nullcontext() if is_valid else pytest.raises(ValueError, match="At most")
with exc_ctx:
processor.apply(
......
......@@ -58,7 +58,7 @@ if __name__ == "__main__":
assert args.phase in profile_data, (
f"Cannot find phase {args.phase} in profile data. Choose one among"
f"{[x for x in profile_data.keys() if 'prefill' in x or 'decode' in x]}"
f"{[x for x in profile_data if 'prefill' in x or 'decode' in x]}"
) # noqa
if args.table == "summary":
......
......@@ -2370,10 +2370,7 @@ class CPUDNNLGEMMHandler:
torch.ops._C.release_dnnl_matmul_handler(self.handler)
if hasattr(torch.ops._C, "create_onednn_mm_handler"):
_supports_onednn = True
else:
_supports_onednn = False
_supports_onednn = bool(hasattr(torch.ops._C, "create_onednn_mm_handler"))
def is_onednn_acl_supported():
......
......@@ -52,12 +52,9 @@ def reshape_and_cache_kernel_flash(
key_ptr + src_key_idx + tile_pos, mask=tile_pos < (num_heads * head_size)
)
if FP8_KV_CACHE:
if key_load.dtype.is_fp8():
key_tile = key_load
else:
# tl.store will do the correct implicit cast to fp8,
# based on the key_cache_ptr.dtype.element_ty
key_tile = key_load / tl.load(k_scale)
key_tile = key_load if key_load.dtype.is_fp8() else key_load / tl.load(k_scale)
else:
key_tile = key_load
......
......@@ -1097,10 +1097,7 @@ def init_distributed_environment(
if local_rank == -1:
# local rank not set, this usually happens in single-node
# setting, where we can use rank as local rank
if distributed_init_method == "env://":
local_rank = envs.LOCAL_RANK
else:
local_rank = rank
local_rank = envs.LOCAL_RANK if distributed_init_method == "env://" else rank
global _WORLD, _NODE_COUNT
if _WORLD is None:
ranks = list(range(torch.distributed.get_world_size()))
......
......@@ -1310,10 +1310,7 @@ def _parse_chat_message_content_part(
modality = None
if part_type == "image_pil":
if content is not None:
image_content = cast(Image.Image, content)
else:
image_content = None
image_content = cast(Image.Image, content) if content is not None else None
mm_parser.parse_image_pil(image_content, uuid)
modality = "image"
elif part_type in ("image_url", "input_image"):
......
......@@ -1018,10 +1018,7 @@ class LLM:
pooling_task = "encode"
if pooling_task is None:
if "embed" in self.supported_tasks:
pooling_task = "embed"
else:
pooling_task = "encode"
pooling_task = "embed" if "embed" in self.supported_tasks else "encode"
logger.warning_once(
"`LLM.encode` is currently using `pooling_task = %s`.\n"
......
......@@ -458,10 +458,7 @@ class RayDistributedExecutor(DistributedExecutorBase):
else:
serialized_data = self.input_encoder.encode(execute_model_req)
outputs = ray.get(self.forward_dag.execute(serialized_data))
if self.use_v1:
output = outputs[0]
else:
output = self.output_decoder.decode(outputs[0])
output = outputs[0] if self.use_v1 else self.output_decoder.decode(outputs[0])
return output
def _run_workers(
......@@ -482,10 +479,7 @@ class RayDistributedExecutor(DistributedExecutorBase):
rather than blocking on the results.
- args/kwargs: All workers share the same args/kwargs
"""
if isinstance(method, str):
sent_method = method
else:
sent_method = cloudpickle.dumps(method)
sent_method = method if isinstance(method, str) else cloudpickle.dumps(method)
del method
if self.use_ray_spmd_worker:
assert not async_run_tensor_parallel_workers_only, (
......@@ -573,8 +567,9 @@ class RayDistributedExecutor(DistributedExecutorBase):
from ray.dag import InputNode, MultiOutputNode
logger.info(
"RAY_CGRAPH_get_timeout is set to %s", os.environ["RAY_CGRAPH_get_timeout"]
) # noqa: SIM112
"RAY_CGRAPH_get_timeout is set to %s",
os.environ["RAY_CGRAPH_get_timeout"], # noqa: SIM112
)
logger.info(
"VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE = %s",
envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE,
......
......@@ -439,10 +439,7 @@ def mean_dim(
output = torch.empty(output_shape, dtype=dtype, device=input.device)
# Reshape output for kernel
if keepdim:
output_2d = output.reshape(M, 1, K).squeeze(1)
else:
output_2d = output.reshape(M, K)
output_2d = output.reshape(M, 1, K).squeeze(1) if keepdim else output.reshape(M, K)
# Launch kernel
grid = (M * K,)
......
......@@ -151,10 +151,7 @@ def chunk_fwd_o(
) -> torch.Tensor:
B, T, Hg, K, V = *q.shape, v.shape[-1]
H = v.shape[-2]
if FLA_GDN_FIX_BT:
BT = 64
else:
BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
BT = 64 if FLA_GDN_FIX_BT else min(chunk_size, max(16, triton.next_power_of_2(T)))
chunk_indices = (
prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
)
......
......@@ -1746,10 +1746,7 @@ def fused_experts_impl(
else:
raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}")
if inplace:
out_hidden_states = hidden_states
else:
out_hidden_states = torch.empty_like(hidden_states)
out_hidden_states = hidden_states if inplace else torch.empty_like(hidden_states)
if use_mxfp4_w4a4:
# Weight has to be dequantized for mxfp4 emulation.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment