Unverified Commit 9d104b5b authored by Aaron Pham's avatar Aaron Pham Committed by GitHub
Browse files

[CI/Build] Update Ruff version (#8469)


Signed-off-by: default avatarAaron Pham <contact@aarnphm.xyz>
Co-authored-by: default avatarCyrus Leung <cyrus.tl.leung@gmail.com>
parent 6ffa3f31
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional, cast
import torch import torch
from pydantic import BaseModel from pydantic import BaseModel
...@@ -79,8 +79,8 @@ class CompressedTensorsConfig(QuantizationConfig): ...@@ -79,8 +79,8 @@ class CompressedTensorsConfig(QuantizationConfig):
@classmethod @classmethod
def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig": def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
target_scheme_map: Dict[str, Any] = dict() target_scheme_map: Dict[str, Any] = dict()
ignore: List[str] = config.get("ignore", None) ignore = cast(List[str], config.get("ignore"))
quant_format: str = config.get("format", None) quant_format = cast(str, config.get("format"))
# The quant_config has multiple config_groups, each containing # The quant_config has multiple config_groups, each containing
# an input_activations key with details about how the activations are # an input_activations key with details about how the activations are
...@@ -200,7 +200,7 @@ class CompressedTensorsConfig(QuantizationConfig): ...@@ -200,7 +200,7 @@ class CompressedTensorsConfig(QuantizationConfig):
is_per_tensor_or_channel_weight = (weight_quant.strategy in [ is_per_tensor_or_channel_weight = (weight_quant.strategy in [
QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL
]) ])
if not (is_symmetric_weight and is_static_weight if not (is_symmetric_weight and is_static_weight # noqa: SIM103
and is_per_tensor_or_channel_weight): and is_per_tensor_or_channel_weight):
return False return False
......
...@@ -132,10 +132,10 @@ class GPTQMarlinConfig(QuantizationConfig): ...@@ -132,10 +132,10 @@ class GPTQMarlinConfig(QuantizationConfig):
def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]): def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]):
# Extract data from quant config. # Extract data from quant config.
quant_method = quant_config.get("quant_method", "").lower() quant_method = quant_config.get("quant_method", "").lower()
num_bits = quant_config.get("bits", None) num_bits = quant_config.get("bits")
group_size = quant_config.get("group_size", None) group_size = quant_config.get("group_size")
sym = quant_config.get("sym", None) sym = quant_config.get("sym")
desc_act = quant_config.get("desc_act", None) desc_act = quant_config.get("desc_act")
if quant_method != "gptq": if quant_method != "gptq":
return False return False
......
...@@ -408,9 +408,7 @@ def is_vllm_tensorized(tensorizer_config: "TensorizerConfig") -> bool: ...@@ -408,9 +408,7 @@ def is_vllm_tensorized(tensorizer_config: "TensorizerConfig") -> bool:
"inferred as vLLM models, so setting vllm_tensorized=True is " "inferred as vLLM models, so setting vllm_tensorized=True is "
"only necessary for models serialized prior to this change.") "only necessary for models serialized prior to this change.")
return True return True
if (".vllm_tensorized_marker" in deserializer): return ".vllm_tensorized_marker" in deserializer
return True
return False
def serialize_vllm_model( def serialize_vllm_model(
......
...@@ -884,7 +884,7 @@ class MiniCPMV(MiniCPMVBaseModel): ...@@ -884,7 +884,7 @@ class MiniCPMV(MiniCPMVBaseModel):
version = str(config.version).split(".") version = str(config.version).split(".")
version = tuple([int(x) for x in version]) version = tuple([int(x) for x in version])
# Dispatch class based on version # Dispatch class based on version
instance_class = _SUPPORT_VERSION.get(version, None) instance_class = _SUPPORT_VERSION.get(version)
if instance_class is None: if instance_class is None:
raise ValueError( raise ValueError(
"Currently, MiniCPMV only supports versions 2.0, 2.5, and 2.6") "Currently, MiniCPMV only supports versions 2.0, 2.5, and 2.6")
......
...@@ -183,10 +183,7 @@ class TP1DraftModelRunner(ModelRunner): ...@@ -183,10 +183,7 @@ class TP1DraftModelRunner(ModelRunner):
return False return False
# TODO: Add soft-tuning prompt adapter support # TODO: Add soft-tuning prompt adapter support
if self.prompt_adapter_config: return not self.prompt_adapter_config
return False
return True
@torch.inference_mode() @torch.inference_mode()
def execute_model( def execute_model(
......
...@@ -104,10 +104,7 @@ class AsyncMetricsCollector: ...@@ -104,10 +104,7 @@ class AsyncMetricsCollector:
if self._rank != 0: if self._rank != 0:
return False return False
if (now - self._last_metrics_collect_time < return now - self._last_metrics_collect_time >= self._rejsample_metrics_collect_interval_s # noqa: E501
self._rejsample_metrics_collect_interval_s):
return False
return True
def _copy_rejsample_metrics_async(self) -> torch.cuda.Event: def _copy_rejsample_metrics_async(self) -> torch.cuda.Event:
"""Copy rejection/typical-acceptance sampling metrics """Copy rejection/typical-acceptance sampling metrics
......
...@@ -35,8 +35,8 @@ class LibEntry(triton.KernelInterface): ...@@ -35,8 +35,8 @@ class LibEntry(triton.KernelInterface):
dns_key = [ dns_key = [
arg.dtype if hasattr( arg.dtype if hasattr(
arg, "data_ptr") else type(arg) if not isinstance(arg, int) arg, "data_ptr") else type(arg) if not isinstance(arg, int)
else "i32" if -(2**31) <= arg and arg <= 2**31 - else "i32" if arg >= -(2**31) and arg <= 2**31 -
1 else "u64" if 2**63 <= arg and arg <= 2**64 - 1 else "i64" 1 else "u64" if arg >= 2**63 and arg <= 2**64 - 1 else "i64"
for arg in dns_args for arg in dns_args
] ]
# const args passed by position # const args passed by position
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment