Commit 9d104b5b (unverified), authored by Aaron Pham, committed by GitHub
Browse files

[CI/Build] Update Ruff version (#8469)


Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
parent 6ffa3f31
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, cast
import torch
from pydantic import BaseModel
......@@ -79,8 +79,8 @@ class CompressedTensorsConfig(QuantizationConfig):
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
target_scheme_map: Dict[str, Any] = dict()
ignore: List[str] = config.get("ignore", None)
quant_format: str = config.get("format", None)
ignore = cast(List[str], config.get("ignore"))
quant_format = cast(str, config.get("format"))
# The quant_config has multiple config_groups, each containing
# an input_activations key with details about how the activations are
......@@ -200,7 +200,7 @@ class CompressedTensorsConfig(QuantizationConfig):
is_per_tensor_or_channel_weight = (weight_quant.strategy in [
QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL
])
if not (is_symmetric_weight and is_static_weight
if not (is_symmetric_weight and is_static_weight # noqa: SIM103
and is_per_tensor_or_channel_weight):
return False
......
......@@ -132,10 +132,10 @@ class GPTQMarlinConfig(QuantizationConfig):
def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]):
# Extract data from quant config.
quant_method = quant_config.get("quant_method", "").lower()
num_bits = quant_config.get("bits", None)
group_size = quant_config.get("group_size", None)
sym = quant_config.get("sym", None)
desc_act = quant_config.get("desc_act", None)
num_bits = quant_config.get("bits")
group_size = quant_config.get("group_size")
sym = quant_config.get("sym")
desc_act = quant_config.get("desc_act")
if quant_method != "gptq":
return False
......
......@@ -408,9 +408,7 @@ def is_vllm_tensorized(tensorizer_config: "TensorizerConfig") -> bool:
"inferred as vLLM models, so setting vllm_tensorized=True is "
"only necessary for models serialized prior to this change.")
return True
if (".vllm_tensorized_marker" in deserializer):
return True
return False
return ".vllm_tensorized_marker" in deserializer
def serialize_vllm_model(
......
......@@ -884,7 +884,7 @@ class MiniCPMV(MiniCPMVBaseModel):
version = str(config.version).split(".")
version = tuple([int(x) for x in version])
# Dispatch class based on version
instance_class = _SUPPORT_VERSION.get(version, None)
instance_class = _SUPPORT_VERSION.get(version)
if instance_class is None:
raise ValueError(
"Currently, MiniCPMV only supports versions 2.0, 2.5, and 2.6")
......
......@@ -183,10 +183,7 @@ class TP1DraftModelRunner(ModelRunner):
return False
# TODO: Add soft-tuning prompt adapter support
if self.prompt_adapter_config:
return False
return True
return not self.prompt_adapter_config
@torch.inference_mode()
def execute_model(
......
......@@ -104,10 +104,7 @@ class AsyncMetricsCollector:
if self._rank != 0:
return False
if (now - self._last_metrics_collect_time <
self._rejsample_metrics_collect_interval_s):
return False
return True
return now - self._last_metrics_collect_time >= self._rejsample_metrics_collect_interval_s # noqa: E501
def _copy_rejsample_metrics_async(self) -> torch.cuda.Event:
"""Copy rejection/typical-acceptance sampling metrics
......
......@@ -35,8 +35,8 @@ class LibEntry(triton.KernelInterface):
dns_key = [
arg.dtype if hasattr(
arg, "data_ptr") else type(arg) if not isinstance(arg, int)
else "i32" if -(2**31) <= arg and arg <= 2**31 -
1 else "u64" if 2**63 <= arg and arg <= 2**64 - 1 else "i64"
else "i32" if arg >= -(2**31) and arg <= 2**31 -
1 else "u64" if arg >= 2**63 and arg <= 2**64 - 1 else "i64"
for arg in dns_args
]
# const args passed by position
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment