Unverified Commit b893d661 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Fix per file ruff ignores related to simplification (#26259)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 6b6e9877
...@@ -101,40 +101,6 @@ include = ["vllm*"] ...@@ -101,40 +101,6 @@ include = ["vllm*"]
"vllm/v1/engine/utils.py" = ["E501"] "vllm/v1/engine/utils.py" = ["E501"]
"vllm/v1/utils.py" = ["E501"] "vllm/v1/utils.py" = ["E501"]
"vllm/v1/worker/gpu_model_runner.py" = ["E501"] "vllm/v1/worker/gpu_model_runner.py" = ["E501"]
## Simplification rules
"tests/distributed/test_expert_placement.py" = ["SIM108"]
"tests/kernels/attention/test_cutlass_mla_decode.py" = ["SIM108"]
"tests/kernels/attention/test_flashmla.py" = ["SIM108"]
"tests/kernels/attention/test_lightning_attn.py" = ["SIM108"]
"tests/kernels/moe/test_pplx_moe.py" = ["SIM108"]
"tests/kernels/quantization/test_cutlass_scaled_mm.py" = ["SIM108"]
"tests/kernels/test_onednn.py" = ["SIM108"]
"tests/kernels/utils.py" = ["SIM108"]
"tests/multimodal/test_processing.py" = ["SIM108"]
"vllm/attention/ops/triton_reshape_and_cache_flash.py" = ["SIM108"]
"vllm/distributed/parallel_state.py" = ["SIM108"]
"vllm/entrypoints/chat_utils.py" = ["SIM108"]
"vllm/entrypoints/llm.py" = ["SIM108"]
"vllm/executor/ray_distributed_executor.py" = ["SIM108", "SIM112"]
"vllm/model_executor/layers/batch_invariant.py" = ["SIM108"]
"vllm/model_executor/layers/fla/ops/chunk_o.py" = ["SIM108"]
"vllm/model_executor/layers/fused_moe/fused_moe.py" = ["SIM108"]
"vllm/model_executor/layers/fused_moe/layer.py" = ["SIM108"]
"vllm/model_executor/layers/fused_moe/modular_kernel.py" = ["SIM108"]
"vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py" = ["SIM108"]
"vllm/model_executor/layers/layernorm.py" = ["SIM108"]
"vllm/model_executor/layers/lightning_attn.py" = ["SIM108"]
"vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py" = ["SIM103"]
"vllm/model_executor/layers/quantization/compressed_tensors/utils.py" = ["SIM110"]
"vllm/model_executor/layers/quantization/quark/utils.py" = ["SIM110"]
"vllm/utils/__init__.py" = ["SIM108"]
"vllm/v1/sample/ops/bad_words.py" = ["SIM108"]
"vllm/v1/sample/rejection_sampler.py" = ["SIM108"]
"vllm/v1/worker/tpu_model_runner.py" = ["SIM108"]
"vllm/_custom_ops.py" = ["SIM108"]
"tools/profiler/print_layerwise_table.py" = ["SIM118"]
## Loop variable binding issues
"tests/kernels/mamba/test_mamba_ssm_ssd.py" = ["B023"]
# End of temporary ignores # End of temporary ignores
[tool.ruff.lint] [tool.ruff.lint]
......
...@@ -12,10 +12,7 @@ def verify_round_robin_pattern(expert_map, ep_rank, ep_size, global_num_experts) ...@@ -12,10 +12,7 @@ def verify_round_robin_pattern(expert_map, ep_rank, ep_size, global_num_experts)
base_experts = global_num_experts // ep_size base_experts = global_num_experts // ep_size
remainder = global_num_experts % ep_size remainder = global_num_experts % ep_size
if ep_rank < remainder: local_num_experts = base_experts + 1 if ep_rank < remainder else base_experts
local_num_experts = base_experts + 1
else:
local_num_experts = base_experts
# Expected expert IDs for this rank in round_robin pattern # Expected expert IDs for this rank in round_robin pattern
# For non-divisible cases, ranks with extra experts start earlier # For non-divisible cases, ranks with extra experts start earlier
......
...@@ -66,10 +66,7 @@ def test_cutlass_mla_decode( ...@@ -66,10 +66,7 @@ def test_cutlass_mla_decode(
b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal, varlen, torch_dtype b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal, varlen, torch_dtype
): ):
device = torch.device("cuda:0") device = torch.device("cuda:0")
if torch_dtype == torch.float8_e4m3fn: init_dtype = torch.bfloat16 if torch_dtype == torch.float8_e4m3fn else torch_dtype
init_dtype = torch.bfloat16
else:
init_dtype = torch_dtype
torch.set_default_dtype(init_dtype) torch.set_default_dtype(init_dtype)
torch.set_default_device(device) torch.set_default_device(device)
torch.cuda.set_device(device) torch.cuda.set_device(device)
......
...@@ -52,10 +52,7 @@ def test_flash_mla( ...@@ -52,10 +52,7 @@ def test_flash_mla(
b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal, varlen, torch_dtype b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal, varlen, torch_dtype
): ):
device = torch.device("cuda:0") device = torch.device("cuda:0")
if torch_dtype == torch.float8_e4m3fn: init_dtype = torch.bfloat16 if torch_dtype == torch.float8_e4m3fn else torch_dtype
init_dtype = torch.bfloat16
else:
init_dtype = torch_dtype
torch.set_default_dtype(init_dtype) torch.set_default_dtype(init_dtype)
torch.set_default_device(device) torch.set_default_device(device)
torch.cuda.set_device(device) torch.cuda.set_device(device)
......
...@@ -33,10 +33,7 @@ def reference_lightning_attention(q, k, v, ed, block_size, kv_history): ...@@ -33,10 +33,7 @@ def reference_lightning_attention(q, k, v, ed, block_size, kv_history):
# More efficient implementation # More efficient implementation
# Convert decay factors to matrix form # Convert decay factors to matrix form
if ed.dim() == 1: decay = torch.exp(-ed).view(1, -1, 1, 1) if ed.dim() == 1 else torch.exp(-ed)
decay = torch.exp(-ed).view(1, -1, 1, 1)
else:
decay = torch.exp(-ed)
for b in range(B): for b in range(B):
for step in range(S): for step in range(S):
......
...@@ -705,10 +705,7 @@ def _pplx_moe( ...@@ -705,10 +705,7 @@ def _pplx_moe(
with set_current_vllm_config(vllm_config), override_config(moe_config): with set_current_vllm_config(vllm_config), override_config(moe_config):
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False) topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
if shared_experts is not None: shared_output = shared_experts(a) if shared_experts is not None else None
shared_output = shared_experts(a)
else:
shared_output = None
torch_output = torch_experts( torch_output = torch_experts(
a, a,
......
...@@ -88,10 +88,7 @@ def cutlass_fp8_gemm_helper( ...@@ -88,10 +88,7 @@ def cutlass_fp8_gemm_helper(
# make scales K-major for blockwise quant, doesn't affect 1D scales # make scales K-major for blockwise quant, doesn't affect 1D scales
scale_b = scale_b.t().contiguous().t() scale_b = scale_b.t().contiguous().t()
if use_bias: bias = torch.rand((n,), device=device, dtype=out_dtype) * 10 if use_bias else None
bias = torch.rand((n,), device=device, dtype=out_dtype) * 10
else:
bias = None
out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
...@@ -122,10 +119,7 @@ def cutlass_int8_gemm_helper( ...@@ -122,10 +119,7 @@ def cutlass_int8_gemm_helper(
scale_a = torch.randn(a_scales_shape, device=device, dtype=torch.float32) scale_a = torch.randn(a_scales_shape, device=device, dtype=torch.float32)
scale_b = torch.randn(b_scales_shape, device=device, dtype=torch.float32) scale_b = torch.randn(b_scales_shape, device=device, dtype=torch.float32)
if use_bias: bias = torch.rand((n,), device=device, dtype=out_dtype) * 10 if use_bias else None
bias = torch.rand((n,), device=device, dtype=out_dtype) * 10
else:
bias = None
out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
......
...@@ -84,10 +84,7 @@ def onednn_int8_gemm_test_helper( ...@@ -84,10 +84,7 @@ def onednn_int8_gemm_test_helper(
azp = None azp = None
azp_adj = None azp_adj = None
if use_bias: bias = torch.rand((n,), device=device, dtype=out_dtype) * 10 if use_bias else None
bias = torch.rand((n,), device=device, dtype=out_dtype) * 10
else:
bias = None
handler = ops.create_onednn_scaled_mm( handler = ops.create_onednn_scaled_mm(
b, b,
......
...@@ -963,13 +963,9 @@ def make_test_metadata( ...@@ -963,13 +963,9 @@ def make_test_metadata(
None if encoder_seq_lens is None else (sum(encoder_seq_lens)) None if encoder_seq_lens is None else (sum(encoder_seq_lens))
) )
if cross_test_params is None: # For encoder/decoder or encoder-only models only, extract *cross-attention*
cross_kv_mmap = None # slot_mapping and block table (kv_mmap)
else: cross_kv_mmap = None if cross_test_params is None else cross_test_params.kv_mmap
# Encoder/decoder or encoder-only models only:
# * Extract *cross-attention* slot_mapping and block table
# (kv_mmap)
cross_kv_mmap = cross_test_params.kv_mmap
attn_backend_obj = make_backend(attn_backend.name) attn_backend_obj = make_backend(attn_backend.name)
......
...@@ -941,10 +941,7 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid): ...@@ -941,10 +941,7 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid):
profiler = MultiModalProfiler(processor) profiler = MultiModalProfiler(processor)
if is_valid: exc_ctx = nullcontext() if is_valid else pytest.raises(ValueError, match="At most")
exc_ctx = nullcontext()
else:
exc_ctx = pytest.raises(ValueError, match="At most")
with exc_ctx: with exc_ctx:
profiler.get_decoder_dummy_data( profiler.get_decoder_dummy_data(
...@@ -985,10 +982,7 @@ def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid): ...@@ -985,10 +982,7 @@ def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid):
else: else:
mm_data = {"image": [image] * num_images} mm_data = {"image": [image] * num_images}
if is_valid: exc_ctx = nullcontext() if is_valid else pytest.raises(ValueError, match="At most")
exc_ctx = nullcontext()
else:
exc_ctx = pytest.raises(ValueError, match="At most")
with exc_ctx: with exc_ctx:
processor.apply( processor.apply(
......
...@@ -58,7 +58,7 @@ if __name__ == "__main__": ...@@ -58,7 +58,7 @@ if __name__ == "__main__":
assert args.phase in profile_data, ( assert args.phase in profile_data, (
f"Cannot find phase {args.phase} in profile data. Choose one among" f"Cannot find phase {args.phase} in profile data. Choose one among"
f"{[x for x in profile_data.keys() if 'prefill' in x or 'decode' in x]}" f"{[x for x in profile_data if 'prefill' in x or 'decode' in x]}"
) # noqa ) # noqa
if args.table == "summary": if args.table == "summary":
......
...@@ -2370,10 +2370,7 @@ class CPUDNNLGEMMHandler: ...@@ -2370,10 +2370,7 @@ class CPUDNNLGEMMHandler:
torch.ops._C.release_dnnl_matmul_handler(self.handler) torch.ops._C.release_dnnl_matmul_handler(self.handler)
if hasattr(torch.ops._C, "create_onednn_mm_handler"): _supports_onednn = bool(hasattr(torch.ops._C, "create_onednn_mm_handler"))
_supports_onednn = True
else:
_supports_onednn = False
def is_onednn_acl_supported(): def is_onednn_acl_supported():
......
...@@ -52,12 +52,9 @@ def reshape_and_cache_kernel_flash( ...@@ -52,12 +52,9 @@ def reshape_and_cache_kernel_flash(
key_ptr + src_key_idx + tile_pos, mask=tile_pos < (num_heads * head_size) key_ptr + src_key_idx + tile_pos, mask=tile_pos < (num_heads * head_size)
) )
if FP8_KV_CACHE: if FP8_KV_CACHE:
if key_load.dtype.is_fp8():
key_tile = key_load
else:
# tl.store will do the correct implicit cast to fp8, # tl.store will do the correct implicit cast to fp8,
# based on the key_cache_ptr.dtype.element_ty # based on the key_cache_ptr.dtype.element_ty
key_tile = key_load / tl.load(k_scale) key_tile = key_load if key_load.dtype.is_fp8() else key_load / tl.load(k_scale)
else: else:
key_tile = key_load key_tile = key_load
......
...@@ -1097,10 +1097,7 @@ def init_distributed_environment( ...@@ -1097,10 +1097,7 @@ def init_distributed_environment(
if local_rank == -1: if local_rank == -1:
# local rank not set, this usually happens in single-node # local rank not set, this usually happens in single-node
# setting, where we can use rank as local rank # setting, where we can use rank as local rank
if distributed_init_method == "env://": local_rank = envs.LOCAL_RANK if distributed_init_method == "env://" else rank
local_rank = envs.LOCAL_RANK
else:
local_rank = rank
global _WORLD, _NODE_COUNT global _WORLD, _NODE_COUNT
if _WORLD is None: if _WORLD is None:
ranks = list(range(torch.distributed.get_world_size())) ranks = list(range(torch.distributed.get_world_size()))
......
...@@ -1310,10 +1310,7 @@ def _parse_chat_message_content_part( ...@@ -1310,10 +1310,7 @@ def _parse_chat_message_content_part(
modality = None modality = None
if part_type == "image_pil": if part_type == "image_pil":
if content is not None: image_content = cast(Image.Image, content) if content is not None else None
image_content = cast(Image.Image, content)
else:
image_content = None
mm_parser.parse_image_pil(image_content, uuid) mm_parser.parse_image_pil(image_content, uuid)
modality = "image" modality = "image"
elif part_type in ("image_url", "input_image"): elif part_type in ("image_url", "input_image"):
......
...@@ -1018,10 +1018,7 @@ class LLM: ...@@ -1018,10 +1018,7 @@ class LLM:
pooling_task = "encode" pooling_task = "encode"
if pooling_task is None: if pooling_task is None:
if "embed" in self.supported_tasks: pooling_task = "embed" if "embed" in self.supported_tasks else "encode"
pooling_task = "embed"
else:
pooling_task = "encode"
logger.warning_once( logger.warning_once(
"`LLM.encode` is currently using `pooling_task = %s`.\n" "`LLM.encode` is currently using `pooling_task = %s`.\n"
......
...@@ -458,10 +458,7 @@ class RayDistributedExecutor(DistributedExecutorBase): ...@@ -458,10 +458,7 @@ class RayDistributedExecutor(DistributedExecutorBase):
else: else:
serialized_data = self.input_encoder.encode(execute_model_req) serialized_data = self.input_encoder.encode(execute_model_req)
outputs = ray.get(self.forward_dag.execute(serialized_data)) outputs = ray.get(self.forward_dag.execute(serialized_data))
if self.use_v1: output = outputs[0] if self.use_v1 else self.output_decoder.decode(outputs[0])
output = outputs[0]
else:
output = self.output_decoder.decode(outputs[0])
return output return output
def _run_workers( def _run_workers(
...@@ -482,10 +479,7 @@ class RayDistributedExecutor(DistributedExecutorBase): ...@@ -482,10 +479,7 @@ class RayDistributedExecutor(DistributedExecutorBase):
rather than blocking on the results. rather than blocking on the results.
- args/kwargs: All workers share the same args/kwargs - args/kwargs: All workers share the same args/kwargs
""" """
if isinstance(method, str): sent_method = method if isinstance(method, str) else cloudpickle.dumps(method)
sent_method = method
else:
sent_method = cloudpickle.dumps(method)
del method del method
if self.use_ray_spmd_worker: if self.use_ray_spmd_worker:
assert not async_run_tensor_parallel_workers_only, ( assert not async_run_tensor_parallel_workers_only, (
...@@ -573,8 +567,9 @@ class RayDistributedExecutor(DistributedExecutorBase): ...@@ -573,8 +567,9 @@ class RayDistributedExecutor(DistributedExecutorBase):
from ray.dag import InputNode, MultiOutputNode from ray.dag import InputNode, MultiOutputNode
logger.info( logger.info(
"RAY_CGRAPH_get_timeout is set to %s", os.environ["RAY_CGRAPH_get_timeout"] "RAY_CGRAPH_get_timeout is set to %s",
) # noqa: SIM112 os.environ["RAY_CGRAPH_get_timeout"], # noqa: SIM112
)
logger.info( logger.info(
"VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE = %s", "VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE = %s",
envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE, envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE,
......
...@@ -439,10 +439,7 @@ def mean_dim( ...@@ -439,10 +439,7 @@ def mean_dim(
output = torch.empty(output_shape, dtype=dtype, device=input.device) output = torch.empty(output_shape, dtype=dtype, device=input.device)
# Reshape output for kernel # Reshape output for kernel
if keepdim: output_2d = output.reshape(M, 1, K).squeeze(1) if keepdim else output.reshape(M, K)
output_2d = output.reshape(M, 1, K).squeeze(1)
else:
output_2d = output.reshape(M, K)
# Launch kernel # Launch kernel
grid = (M * K,) grid = (M * K,)
......
...@@ -151,10 +151,7 @@ def chunk_fwd_o( ...@@ -151,10 +151,7 @@ def chunk_fwd_o(
) -> torch.Tensor: ) -> torch.Tensor:
B, T, Hg, K, V = *q.shape, v.shape[-1] B, T, Hg, K, V = *q.shape, v.shape[-1]
H = v.shape[-2] H = v.shape[-2]
if FLA_GDN_FIX_BT: BT = 64 if FLA_GDN_FIX_BT else min(chunk_size, max(16, triton.next_power_of_2(T)))
BT = 64
else:
BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
chunk_indices = ( chunk_indices = (
prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
) )
......
...@@ -1746,10 +1746,7 @@ def fused_experts_impl( ...@@ -1746,10 +1746,7 @@ def fused_experts_impl(
else: else:
raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}") raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}")
if inplace: out_hidden_states = hidden_states if inplace else torch.empty_like(hidden_states)
out_hidden_states = hidden_states
else:
out_hidden_states = torch.empty_like(hidden_states)
if use_mxfp4_w4a4: if use_mxfp4_w4a4:
# Weight has to be dequantized for mxfp4 emulation. # Weight has to be dequantized for mxfp4 emulation.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment