Unverified Commit b893d661 authored by Harry Mellor, committed by GitHub
Browse files

Fix per file ruff ignores related to simplification (#26259)


Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
parent 6b6e9877
......@@ -886,10 +886,7 @@ def determine_expert_map(
# Distribute experts as evenly as possible to each rank.
base_experts = global_num_experts // ep_size
remainder = global_num_experts % ep_size
if ep_rank < remainder:
local_num_experts = base_experts + 1
else:
local_num_experts = base_experts
local_num_experts = base_experts + 1 if ep_rank < remainder else base_experts
# Create a tensor of size num_experts filled with -1
expert_map = torch.full((global_num_experts,), -1, dtype=torch.int32)
......
......@@ -948,10 +948,7 @@ class FusedMoEModularKernel(torch.nn.Module):
"""
a1 = hidden_states
if inplace and self.shared_experts is None:
output = a1
else:
output = torch.zeros_like(a1)
output = a1 if inplace and self.shared_experts is None else torch.zeros_like(a1)
local_num_experts = w1.size(0)
if global_num_experts == -1:
......
......@@ -355,10 +355,7 @@ def rocm_aiter_fused_experts(
topk_weights = topk_weights.to(torch.float32)
topk_ids = topk_ids.to(torch.int32)
if expert_map is not None:
expert_mask = (expert_map > -1).to(torch.int32)
else:
expert_mask = None
expert_mask = (expert_map > -1).to(torch.int32) if expert_map is not None else None
# w8a8 per-channel quantization
if (
......
......@@ -318,10 +318,7 @@ class GemmaRMSNorm(CustomOp):
"""PyTorch-native implementation equivalent to forward()."""
orig_dtype = x.dtype
if residual is not None:
if orig_dtype == torch.float16:
x = x + residual.float()
else:
x = x + residual
x = x + residual.float() if orig_dtype == torch.float16 else x + residual
residual = x
x = x.float()
......
......@@ -207,10 +207,7 @@ def _fwd_kv_parallel(
kv = tl.zeros([D_FBLOCK, E_FBLOCK], dtype=tl.float32)
# Handle the last block which might be smaller than BLOCK
if off_block == NUM_BLOCK - 1:
split_n = n - (NUM_BLOCK - 1) * BLOCK
else:
split_n = BLOCK
split_n = n - (NUM_BLOCK - 1) * BLOCK if off_block == NUM_BLOCK - 1 else BLOCK
left_shift = tl.cdiv(split_n, CBLOCK) * CBLOCK - split_n
num_blocks = min(tl.cdiv(split_n, CBLOCK), NUM_CBLOCK)
k_decay_ptr += (NUM_CBLOCK - num_blocks) * CBLOCK
......
......@@ -502,15 +502,11 @@ class CompressedTensorsConfig(QuantizationConfig):
QuantizationStrategy.CHANNEL,
QuantizationStrategy.BLOCK,
]
if not (
return (
is_symmetric_weight
and is_static_weight # noqa: SIM103
and is_static_weight
and is_tensor_or_channel_or_block_weight
):
return False
# All conditions satisfied.
return True
)
def _is_wNa16_group_channel(
self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs
......
......@@ -80,10 +80,7 @@ def check_equal_or_regex_match(layer_name: str, targets: Iterable[str]) -> bool:
Checks whether a layer_name is exactly equal or a regex match for
if target starts with 're:' to any target in list.
"""
for target in targets:
if _is_equal_or_regex_match(layer_name, target):
return True
return False
return any(_is_equal_or_regex_match(layer_name, target) for target in targets)
def find_matched_target(
......
......@@ -81,10 +81,7 @@ def check_equal_or_regex_match(layer_name: str, targets: Iterable[str]) -> bool:
Checks whether a layer_name is exactly equal or a regex match for
if target starts with 're:' to any target in list.
"""
for target in targets:
if _is_equal_or_regex_match(layer_name, target):
return True
return False
return any(_is_equal_or_regex_match(layer_name, target) for target in targets)
def _is_equal_or_regex_match(
......
......@@ -3052,10 +3052,7 @@ def make_zmq_socket(
# - Set a large 0.5GB buffer to improve throughput
# For systems with less memory:
# - Use system default (-1) to avoid excessive memory consumption
if total_mem > 32 and available_mem > 16:
buf_size = int(0.5 * 1024**3) # 0.5GB in bytes
else:
buf_size = -1 # Use system default buffer size
buf_size = int(0.5 * 1024**3) if total_mem > 32 and available_mem > 16 else -1
if bind is None:
bind = socket_type not in (zmq.PUSH, zmq.SUB, zmq.XSUB)
......
......@@ -17,10 +17,7 @@ def _apply_bad_words_single_batch(
prefix_length = len(bad_word_ids) - 1
last_token_id = bad_word_ids[-1]
if prefix_length > 0:
actual_prefix = past_tokens_ids[-prefix_length:]
else:
actual_prefix = []
actual_prefix = past_tokens_ids[-prefix_length:] if prefix_length > 0 else []
expected_prefix = bad_word_ids[:prefix_length]
assert len(actual_prefix) == len(expected_prefix)
......
......@@ -444,18 +444,12 @@ def rejection_greedy_sample_kernel(
req_idx = tl.program_id(0)
# FIXME(woosuk): Because is_greedy_ptr is not None at profiling run,
# re-compilation may happen during runtime when is_greedy_ptr is None.
if is_greedy_ptr is None:
is_greedy = True
else:
is_greedy = tl.load(is_greedy_ptr + req_idx)
is_greedy = True if is_greedy_ptr is None else tl.load(is_greedy_ptr + req_idx)
if not is_greedy:
# Early exit for non-greedy sampling requests.
return
if req_idx == 0:
start_idx = 0
else:
start_idx = tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
num_draft_tokens = end_idx - start_idx
......@@ -503,10 +497,7 @@ def rejection_random_sample_kernel(
# Early exit for greedy sampling requests.
return
if req_idx == 0:
start_idx = 0
else:
start_idx = tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
num_draft_tokens = end_idx - start_idx
......@@ -583,10 +574,7 @@ def sample_recovered_tokens_kernel(
NO_DRAFT_PROBS: tl.constexpr,
):
req_idx = tl.program_id(0)
if req_idx == 0:
start_idx = 0
else:
start_idx = tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
num_draft_tokens = end_idx - start_idx
......
......@@ -507,12 +507,8 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
removed_req_indices = sorted(removed_req_indices, reverse=True)
for req_id in req_ids_to_add:
req_state = self.requests[req_id]
if removed_req_indices:
# Fill the empty index.
req_index = removed_req_indices.pop()
else:
# Append to the end.
req_index = None
# Fill the empty index or append to the end
req_index = removed_req_indices.pop() if removed_req_indices else None
self.input_batch.add_request(req_state, req_index)
# Condense the batched states if there are empty indices.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment