Unverified Commit bab02ff2 authored by drbh, committed by GitHub

feat: add ruff and resolve issue (#2262)

* feat: add ruff and resolve issue

* fix: update client exports and adjust after rebase

* fix: adjust syntax to avoid circular import

* fix: adjust client ruff settings

* fix: lint and refactor import check and avoid model enum as global names

* fix: improve fbgemm_gpu check and lints

* fix: update lints

* fix: prefer comparing model enum over str

* fix: adjust lints and ignore specific rules

* fix: avoid unneeded quantize check
parent 4b49c50f
......@@ -16,3 +16,8 @@ repos:
- id: fmt
- id: cargo-check
- id: clippy
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.3.0
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
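A note on the hook's flags: `--fix` rewrites fixable violations in the staged files, and `--exit-non-zero-on-fix` still fails the run so the auto-applied changes can be reviewed and re-staged. A hypothetical local equivalent, assuming `ruff` is on PATH:

    import subprocess

    # Mirrors the pre-commit hook: apply ruff's autofixes, but report a non-zero exit if anything changed.
    result = subprocess.run(["ruff", "check", "--fix", "--exit-non-zero-on-fix", "."])
    print("ruff applied fixes or found issues" if result.returncode != 0 else "clean")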
......@@ -19,5 +19,15 @@ DEPRECATION_WARNING = (
"Please use the `InferenceClient` from the `huggingface_hub` package instead."
)
from text_generation.client import Client, AsyncClient
from text_generation.inference_api import InferenceAPIClient, InferenceAPIAsyncClient
from text_generation.client import Client, AsyncClient # noqa E402
from text_generation.inference_api import ( # noqa E402
InferenceAPIClient,
InferenceAPIAsyncClient,
)
__all__ = [
"Client",
"AsyncClient",
"InferenceAPIClient",
"InferenceAPIAsyncClient",
]
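The explicit `__all__` marks these re-exports as the package's public interface (otherwise ruff's F401 treats them as unused imports), and the `# noqa E402` comments acknowledge that the imports deliberately follow the module-level `DEPRECATION_WARNING` assignment. A hypothetical `__init__`-style sketch of the same pattern, with stdlib imports standing in for the real ones:

    DEPRECATION_WARNING = "defined before the imports, hence the E402 suppressions below"

    from collections import OrderedDict  # noqa: E402
    from functools import lru_cache  # noqa: E402

    # Declaring the re-exports keeps ruff's F401 from flagging them as unused.
    __all__ = ["DEPRECATION_WARNING", "OrderedDict", "lru_cache"]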
......@@ -21,7 +21,7 @@ def deployed_models(headers: Optional[Dict] = None) -> List[DeployedModel]:
List[DeployedModel]: list of all currently deployed models
"""
resp = requests.get(
f"https://api-inference.huggingface.co/framework/text-generation-inference",
"https://api-inference.huggingface.co/framework/text-generation-inference",
headers=headers,
timeout=5,
)
......
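The `f` prefix is dropped because the URL contains no placeholders; ruff flags that as F541 (f-string without any placeholders). A hypothetical side-by-side:

    model = "text-generation-inference"

    flagged = f"https://api-inference.huggingface.co/framework/text-generation-inference"  # F541: no {} inside
    fixed = "https://api-inference.huggingface.co/framework/text-generation-inference"
    fine = f"https://api-inference.huggingface.co/framework/{model}"  # placeholder present, so no warning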
......@@ -4,7 +4,6 @@ import json
import math
import os
import random
import re
import shutil
import subprocess
import sys
......@@ -271,7 +270,7 @@ class LauncherHandle:
try:
await self.client.generate("test")
return
except (ClientConnectorError, ClientOSError, ServerDisconnectedError) as e:
except (ClientConnectorError, ClientOSError, ServerDisconnectedError):
time.sleep(1)
raise RuntimeError("Health check failed")
......
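Dropping `as e` removes a binding that was never read (ruff reports it as an unused variable) without changing which exceptions are caught. A hypothetical sketch of the same retry shape, using builtin exception types:

    import time

    def wait_until_healthy(check, attempts=3):
        for _ in range(attempts):
            try:
                check()
                return
            except (ConnectionError, OSError):  # no `as e`: the exception object is never used
                time.sleep(1)
        raise RuntimeError("Health check failed")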
import pytest
import json
from text_generation.types import GrammarType
@pytest.fixture(scope="module")
......
import pytest
import requests
import io
import base64
......
......@@ -74,9 +74,7 @@ async def test_idefics_load(idefics, generate_load, response_snapshot):
generated_texts = [r.generated_text for r in responses]
assert (
generated_texts[0] == " \nAssistant: A rooster stands"
), f"{response.generated_text}"
assert generated_texts[0] == " \nAssistant: A rooster stands"
assert len(generated_texts) == 4
assert generated_texts, all(
[text == generated_texts[0] for text in generated_texts]
......
import pytest
import json
from text_generation.types import GrammarType
@pytest.fixture(scope="module")
......@@ -91,7 +88,7 @@ async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_sna
},
],
)
assert response.choices[0].message.content == None
assert response.choices[0].message.content is None
assert response.choices[0].message.tool_calls == [
{
"id": 0,
......@@ -129,7 +126,7 @@ async def test_flash_llama_grammar_tools_auto(
},
],
)
assert response.choices[0].message.content == None
assert response.choices[0].message.content is None
assert response.choices[0].message.tool_calls == [
{
"id": 0,
......@@ -168,7 +165,7 @@ async def test_flash_llama_grammar_tools_choice(
},
],
)
assert response.choices[0].message.content == None
assert response.choices[0].message.content is None
assert response.choices[0].message.tool_calls == [
{
"id": 0,
......@@ -241,7 +238,7 @@ async def test_flash_llama_grammar_tools_insufficient_information(
stream=False,
)
assert responses.choices[0].message.content == None
assert responses.choices[0].message.content is None
assert responses.choices[0].message.tool_calls == [
{
"function": {
......
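The switch from `== None` to `is None` in these tests is rule E711: comparisons with None should use identity, since `==` dispatches to `__eq__` and can be overridden. A hypothetical illustration:

    content = None

    assert content == None  # E711: equality check can be fooled by a custom __eq__
    assert content is None  # preferred: identity comparison, as in the updated tests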
......@@ -20,7 +20,7 @@ def main():
break
with open("./small.json", "w") as f:
data = json.dump(conversations, f, indent=4)
json.dump(conversations, f, indent=4)
if __name__ == "__main__":
......
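`json.dump` writes to the file handle and returns `None`, so the old `data = ...` binding was dead code (an unused-variable warning). A hypothetical minimal version:

    import json

    conversations = [{"role": "user", "content": "hello"}]
    with open("./small.json", "w") as f:
        json.dump(conversations, f, indent=4)  # write directly; there is nothing useful to assign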
import os
import requests
import tempfile
import pytest
import huggingface_hub.constants
from huggingface_hub import hf_api
import text_generation_server.utils.hub
from text_generation_server.utils.hub import (
......
......@@ -2,7 +2,6 @@ import torch
from text_generation_server.layers import (
TensorParallelEmbedding,
)
from text_generation_server.utils.weights import DefaultWeightsLoader
class ProcessGroup:
......
......@@ -2,7 +2,6 @@ import pytest
import torch
from text_generation_server.utils.weights import (
DefaultWeightsLoader,
UnquantizedWeight,
Weights,
WeightsLoader,
)
......@@ -86,15 +85,6 @@ dummy_file_system = {
],
dtype=torch.float32,
),
"weight.weight": torch.tensor(
[
[1, 2],
[3, 4],
[5, 6],
[7, 8],
],
dtype=torch.float32,
),
},
"test_get_weights_row": {
"weight.weight": torch.tensor(
......@@ -966,7 +956,7 @@ def test_get_multi_weights_col_exl2():
prefix = "weight"
try:
w = weights.get_multi_weights_col(
weights.get_multi_weights_col(
prefixes=[prefix],
dim=0,
)
......
......@@ -4,15 +4,12 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, Set, Tuple
from typing import Dict, Set, Tuple
import torch
from text_generation_server.adapters.weights import AdapterWeights
if TYPE_CHECKING:
from text_generation_server.models.model import Model
@dataclass
class ModuleMap:
......
......@@ -4,7 +4,7 @@
from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Type, Union
from typing import Dict, List, Optional, Set, Tuple, Type, Union
import torch
from peft import LoraConfig as _LoraConfig
......@@ -26,9 +26,6 @@ from text_generation_server.utils.sgmv import (
use_cutlass_shrink,
)
if TYPE_CHECKING:
from text_generation_server.models.model import Model
def get_start_stop_idxs_for_rank(offset, size, rank, world_size):
block_size = size // world_size
......
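Once no annotation references `Model`, both the `if TYPE_CHECKING:` import and `TYPE_CHECKING` itself are unused, so ruff removes them (F401). A hypothetical sketch of when such a guarded import is still worth keeping, with a stdlib type standing in for `Model`:

    from typing import TYPE_CHECKING, Dict

    if TYPE_CHECKING:
        from decimal import Decimal  # only needed while an annotation below mentions it

    def total_items(prices: "Dict[str, Decimal]") -> int:
        # Drop this string annotation and the guarded import above becomes an unused import (F401).
        return len(prices)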
......@@ -4,12 +4,11 @@ import typer
from pathlib import Path
from loguru import logger
from typing import Optional, List, Dict
from typing import Optional
from enum import Enum
from huggingface_hub import hf_hub_download
from text_generation_server.utils.adapter import parse_lora_adapters
from text_generation_server.utils.log import log_master
app = typer.Typer()
......@@ -165,7 +164,7 @@ def download_weights(
# currently by default we don't merge the weights with the base model
if merge_lora:
try:
adapter_config_filename = hf_hub_download(
hf_hub_download(
model_id, revision=revision, filename="adapter_config.json"
)
utils.download_and_unload_peft(
......@@ -285,9 +284,9 @@ def download_weights(
if auto_convert:
if not trust_remote_code:
logger.warning(
f"🚨🚨BREAKING CHANGE in 2.0🚨🚨: Safetensors conversion is disabled without `--trust-remote-code` because "
f"Pickle files are unsafe and can essentially contain remote code execution!"
f"Please check for more information here: https://huggingface.co/docs/text-generation-inference/basic_tutorials/safety",
"🚨🚨BREAKING CHANGE in 2.0🚨🚨: Safetensors conversion is disabled without `--trust-remote-code` because "
"Pickle files are unsafe and can essentially contain remote code execution!"
"Please check for more information here: https://huggingface.co/docs/text-generation-inference/basic_tutorials/safety",
)
logger.warning(
......@@ -319,7 +318,7 @@ def download_weights(
# Name for this varible depends on transformers version.
discard_names = getattr(class_, "_tied_weights_keys", [])
except Exception as e:
except Exception:
discard_names = []
# Convert pytorch weights to safetensors
utils.convert_files(local_pt_files, local_st_files, discard_names)
......
......@@ -18,3 +18,17 @@ from text_generation_server.layers.lora import (
TensorParallelMultiAdapterLinear,
TensorParallelAdapterRowLinear,
)
__all__ = [
"get_linear",
"FastLinear",
"TensorParallelColumnLinear",
"TensorParallelRowLinear",
"TensorParallelEmbedding",
"SpeculativeHead",
"LoraLinear",
"TensorParallelMultiAdapterLinear",
"TensorParallelAdapterRowLinear",
"load_layer_norm",
"load_conv2d",
]
......@@ -13,3 +13,12 @@ elif SYSTEM == "ipex":
from .ipex import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
else:
raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")
__all__ = [
"attention",
"paged_attention",
"reshape_and_cache",
"SUPPORTS_WINDOWING",
"Seqlen",
]
......@@ -10,7 +10,6 @@ _PARTITION_SIZE = 512
try:
from vllm._C import cache_ops
from vllm._C import ops
except Exception as e:
raise ImportError(
f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
......
......@@ -747,11 +747,8 @@ class _attention(torch.autograd.Function):
padded_d_model = 1 << (head_size - 1).bit_length()
padded_d_model = max(padded_d_model, 16)
grid = lambda META: (
triton.cdiv(max_seqlens_q, META["BLOCK_M"]),
nheads_q,
batch,
)
def grid(META):
return triton.cdiv(max_seqlens_q, META["BLOCK_M"]), nheads_q, batch
encoded_softmax = None
......
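Replacing the assigned lambda with a `def` is rule E731 (do not assign a lambda expression to a name); the `def` keeps the same call signature while giving the callable a real name for tracebacks. A hypothetical before/after with made-up keys:

    grid = lambda meta: (meta["blocks_m"], meta["heads"], meta["batch"])  # flagged by E731

    def grid(meta):  # equivalent def, mirroring the Triton launch-grid change above
        return meta["blocks_m"], meta["heads"], meta["batch"]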
......@@ -15,7 +15,6 @@ ENGINE = "triton" if use_triton else "ck"
try:
from vllm._C import cache_ops
from vllm._C import ops
except Exception as e:
raise ImportError(
f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
......