Unverified Commit b893d661 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Fix per file ruff ignores related to simplification (#26259)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 6b6e9877
...@@ -886,10 +886,7 @@ def determine_expert_map( ...@@ -886,10 +886,7 @@ def determine_expert_map(
# Distribute experts as evenly as possible to each rank. # Distribute experts as evenly as possible to each rank.
base_experts = global_num_experts // ep_size base_experts = global_num_experts // ep_size
remainder = global_num_experts % ep_size remainder = global_num_experts % ep_size
if ep_rank < remainder: local_num_experts = base_experts + 1 if ep_rank < remainder else base_experts
local_num_experts = base_experts + 1
else:
local_num_experts = base_experts
# Create a tensor of size num_experts filled with -1 # Create a tensor of size num_experts filled with -1
expert_map = torch.full((global_num_experts,), -1, dtype=torch.int32) expert_map = torch.full((global_num_experts,), -1, dtype=torch.int32)
......
...@@ -948,10 +948,7 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -948,10 +948,7 @@ class FusedMoEModularKernel(torch.nn.Module):
""" """
a1 = hidden_states a1 = hidden_states
if inplace and self.shared_experts is None: output = a1 if inplace and self.shared_experts is None else torch.zeros_like(a1)
output = a1
else:
output = torch.zeros_like(a1)
local_num_experts = w1.size(0) local_num_experts = w1.size(0)
if global_num_experts == -1: if global_num_experts == -1:
......
...@@ -355,10 +355,7 @@ def rocm_aiter_fused_experts( ...@@ -355,10 +355,7 @@ def rocm_aiter_fused_experts(
topk_weights = topk_weights.to(torch.float32) topk_weights = topk_weights.to(torch.float32)
topk_ids = topk_ids.to(torch.int32) topk_ids = topk_ids.to(torch.int32)
if expert_map is not None: expert_mask = (expert_map > -1).to(torch.int32) if expert_map is not None else None
expert_mask = (expert_map > -1).to(torch.int32)
else:
expert_mask = None
# w8a8 per-channel quantization # w8a8 per-channel quantization
if ( if (
......
...@@ -318,10 +318,7 @@ class GemmaRMSNorm(CustomOp): ...@@ -318,10 +318,7 @@ class GemmaRMSNorm(CustomOp):
"""PyTorch-native implementation equivalent to forward().""" """PyTorch-native implementation equivalent to forward()."""
orig_dtype = x.dtype orig_dtype = x.dtype
if residual is not None: if residual is not None:
if orig_dtype == torch.float16: x = x + residual.float() if orig_dtype == torch.float16 else x + residual
x = x + residual.float()
else:
x = x + residual
residual = x residual = x
x = x.float() x = x.float()
......
...@@ -207,10 +207,7 @@ def _fwd_kv_parallel( ...@@ -207,10 +207,7 @@ def _fwd_kv_parallel(
kv = tl.zeros([D_FBLOCK, E_FBLOCK], dtype=tl.float32) kv = tl.zeros([D_FBLOCK, E_FBLOCK], dtype=tl.float32)
# Handle the last block which might be smaller than BLOCK # Handle the last block which might be smaller than BLOCK
if off_block == NUM_BLOCK - 1: split_n = n - (NUM_BLOCK - 1) * BLOCK if off_block == NUM_BLOCK - 1 else BLOCK
split_n = n - (NUM_BLOCK - 1) * BLOCK
else:
split_n = BLOCK
left_shift = tl.cdiv(split_n, CBLOCK) * CBLOCK - split_n left_shift = tl.cdiv(split_n, CBLOCK) * CBLOCK - split_n
num_blocks = min(tl.cdiv(split_n, CBLOCK), NUM_CBLOCK) num_blocks = min(tl.cdiv(split_n, CBLOCK), NUM_CBLOCK)
k_decay_ptr += (NUM_CBLOCK - num_blocks) * CBLOCK k_decay_ptr += (NUM_CBLOCK - num_blocks) * CBLOCK
......
...@@ -502,15 +502,11 @@ class CompressedTensorsConfig(QuantizationConfig): ...@@ -502,15 +502,11 @@ class CompressedTensorsConfig(QuantizationConfig):
QuantizationStrategy.CHANNEL, QuantizationStrategy.CHANNEL,
QuantizationStrategy.BLOCK, QuantizationStrategy.BLOCK,
] ]
if not ( return (
is_symmetric_weight is_symmetric_weight
and is_static_weight # noqa: SIM103 and is_static_weight
and is_tensor_or_channel_or_block_weight and is_tensor_or_channel_or_block_weight
): )
return False
# All conditions satisfied.
return True
def _is_wNa16_group_channel( def _is_wNa16_group_channel(
self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs
......
...@@ -80,10 +80,7 @@ def check_equal_or_regex_match(layer_name: str, targets: Iterable[str]) -> bool: ...@@ -80,10 +80,7 @@ def check_equal_or_regex_match(layer_name: str, targets: Iterable[str]) -> bool:
Checks whether a layer_name is exactly equal or a regex match for Checks whether a layer_name is exactly equal or a regex match for
if target starts with 're:' to any target in list. if target starts with 're:' to any target in list.
""" """
for target in targets: return any(_is_equal_or_regex_match(layer_name, target) for target in targets)
if _is_equal_or_regex_match(layer_name, target):
return True
return False
def find_matched_target( def find_matched_target(
......
...@@ -81,10 +81,7 @@ def check_equal_or_regex_match(layer_name: str, targets: Iterable[str]) -> bool: ...@@ -81,10 +81,7 @@ def check_equal_or_regex_match(layer_name: str, targets: Iterable[str]) -> bool:
Checks whether a layer_name is exactly equal or a regex match for Checks whether a layer_name is exactly equal or a regex match for
if target starts with 're:' to any target in list. if target starts with 're:' to any target in list.
""" """
for target in targets: return any(_is_equal_or_regex_match(layer_name, target) for target in targets)
if _is_equal_or_regex_match(layer_name, target):
return True
return False
def _is_equal_or_regex_match( def _is_equal_or_regex_match(
......
...@@ -3052,10 +3052,7 @@ def make_zmq_socket( ...@@ -3052,10 +3052,7 @@ def make_zmq_socket(
# - Set a large 0.5GB buffer to improve throughput # - Set a large 0.5GB buffer to improve throughput
# For systems with less memory: # For systems with less memory:
# - Use system default (-1) to avoid excessive memory consumption # - Use system default (-1) to avoid excessive memory consumption
if total_mem > 32 and available_mem > 16: buf_size = int(0.5 * 1024**3) if total_mem > 32 and available_mem > 16 else -1
buf_size = int(0.5 * 1024**3) # 0.5GB in bytes
else:
buf_size = -1 # Use system default buffer size
if bind is None: if bind is None:
bind = socket_type not in (zmq.PUSH, zmq.SUB, zmq.XSUB) bind = socket_type not in (zmq.PUSH, zmq.SUB, zmq.XSUB)
......
...@@ -17,10 +17,7 @@ def _apply_bad_words_single_batch( ...@@ -17,10 +17,7 @@ def _apply_bad_words_single_batch(
prefix_length = len(bad_word_ids) - 1 prefix_length = len(bad_word_ids) - 1
last_token_id = bad_word_ids[-1] last_token_id = bad_word_ids[-1]
if prefix_length > 0: actual_prefix = past_tokens_ids[-prefix_length:] if prefix_length > 0 else []
actual_prefix = past_tokens_ids[-prefix_length:]
else:
actual_prefix = []
expected_prefix = bad_word_ids[:prefix_length] expected_prefix = bad_word_ids[:prefix_length]
assert len(actual_prefix) == len(expected_prefix) assert len(actual_prefix) == len(expected_prefix)
......
...@@ -444,18 +444,12 @@ def rejection_greedy_sample_kernel( ...@@ -444,18 +444,12 @@ def rejection_greedy_sample_kernel(
req_idx = tl.program_id(0) req_idx = tl.program_id(0)
# FIXME(woosuk): Because is_greedy_ptr is not None at profiling run, # FIXME(woosuk): Because is_greedy_ptr is not None at profiling run,
# re-compilation may happen during runtime when is_greedy_ptr is None. # re-compilation may happen during runtime when is_greedy_ptr is None.
if is_greedy_ptr is None: is_greedy = True if is_greedy_ptr is None else tl.load(is_greedy_ptr + req_idx)
is_greedy = True
else:
is_greedy = tl.load(is_greedy_ptr + req_idx)
if not is_greedy: if not is_greedy:
# Early exit for non-greedy sampling requests. # Early exit for non-greedy sampling requests.
return return
if req_idx == 0: start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
start_idx = 0
else:
start_idx = tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx) end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
num_draft_tokens = end_idx - start_idx num_draft_tokens = end_idx - start_idx
...@@ -503,10 +497,7 @@ def rejection_random_sample_kernel( ...@@ -503,10 +497,7 @@ def rejection_random_sample_kernel(
# Early exit for greedy sampling requests. # Early exit for greedy sampling requests.
return return
if req_idx == 0: start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
start_idx = 0
else:
start_idx = tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx) end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
num_draft_tokens = end_idx - start_idx num_draft_tokens = end_idx - start_idx
...@@ -583,10 +574,7 @@ def sample_recovered_tokens_kernel( ...@@ -583,10 +574,7 @@ def sample_recovered_tokens_kernel(
NO_DRAFT_PROBS: tl.constexpr, NO_DRAFT_PROBS: tl.constexpr,
): ):
req_idx = tl.program_id(0) req_idx = tl.program_id(0)
if req_idx == 0: start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
start_idx = 0
else:
start_idx = tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx) end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
num_draft_tokens = end_idx - start_idx num_draft_tokens = end_idx - start_idx
......
...@@ -507,12 +507,8 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -507,12 +507,8 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
removed_req_indices = sorted(removed_req_indices, reverse=True) removed_req_indices = sorted(removed_req_indices, reverse=True)
for req_id in req_ids_to_add: for req_id in req_ids_to_add:
req_state = self.requests[req_id] req_state = self.requests[req_id]
if removed_req_indices: # Fill the empty index or append to the end
# Fill the empty index. req_index = removed_req_indices.pop() if removed_req_indices else None
req_index = removed_req_indices.pop()
else:
# Append to the end.
req_index = None
self.input_batch.add_request(req_state, req_index) self.input_batch.add_request(req_state, req_index)
# Condense the batched states if there are empty indices. # Condense the batched states if there are empty indices.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment