Unverified Commit bab02ff2 authored by drbh, committed by GitHub

feat: add ruff and resolve issue (#2262)

* feat: add ruff and resolve issue

* fix: update client exports and adjust after rebase

* fix: adjust syntax to avoid circular import

* fix: adjust client ruff settings

* fix: lint and refactor import check and avoid model enum as global names

* fix: improve fbgemm_gpu check and lints

* fix: update lints

* fix: prefer comparing model enum over str

* fix: adjust lints and ignore specific rules

* fix: avoid unneeded quantize check
parent 4b49c50f
@@ -16,3 +16,8 @@ repos:
       - id: fmt
       - id: cargo-check
       - id: clippy
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    rev: v0.3.0
+    hooks:
+      - id: ruff
+        args: [--fix, --exit-non-zero-on-fix]
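The new hook runs ruff over the staged files; `--fix` lets it rewrite autofixable findings, and `--exit-non-zero-on-fix` makes the commit fail whenever something was changed so the fixes can be reviewed and re-staged. A rough sketch of what that amounts to, assuming ruff is installed and using example paths (not the hook's real invocation):

```python
import subprocess
import sys

# Run ruff over some example paths, letting it apply autofixes; a non-zero
# exit status (including "I fixed something") aborts the commit, which is
# what the pre-commit hook relies on.
result = subprocess.run(
    ["ruff", "check", "--fix", "--exit-non-zero-on-fix", "clients/python", "server"],
    check=False,
)
sys.exit(result.returncode)
```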
@@ -19,5 +19,15 @@ DEPRECATION_WARNING = (
     "Please use the `InferenceClient` from the `huggingface_hub` package instead."
 )
-from text_generation.client import Client, AsyncClient
-from text_generation.inference_api import InferenceAPIClient, InferenceAPIAsyncClient
+from text_generation.client import Client, AsyncClient  # noqa E402
+from text_generation.inference_api import (  # noqa E402
+    InferenceAPIClient,
+    InferenceAPIAsyncClient,
+)
+
+__all__ = [
+    "Client",
+    "AsyncClient",
+    "InferenceAPIClient",
+    "InferenceAPIAsyncClient",
+]
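Two lint rules meet in this hunk: the imports follow module-level code (the deprecation warning), which trips E402 unless suppressed, and names imported only to be re-exported look unused (F401) unless `__all__` declares them. A minimal sketch of the same pattern, using a standard-library name as a stand-in for the re-export:

```python
import warnings

# Module-level code (the warning below) runs before the later import, so
# that import would trip E402 ("module level import not at top of file")
# unless suppressed, exactly as in the client __init__ above.
warnings.warn(
    "this module is deprecated, import from somewhere else",
    DeprecationWarning,
)

from collections import OrderedDict  # noqa: E402

# __all__ marks OrderedDict as an intentional re-export, so the import is
# not reported as unused (F401) even though this module never calls it.
__all__ = ["OrderedDict"]
```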
@@ -21,7 +21,7 @@ def deployed_models(headers: Optional[Dict] = None) -> List[DeployedModel]:
         List[DeployedModel]: list of all currently deployed models
     """
     resp = requests.get(
-        f"https://api-inference.huggingface.co/framework/text-generation-inference",
+        "https://api-inference.huggingface.co/framework/text-generation-inference",
         headers=headers,
         timeout=5,
     )
......
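The hunk above is an F541 fix: an f-string with no placeholders is just a plain string literal, so the `f` prefix is dropped. A minimal illustration of the same change:

```python
# F541: no placeholders, so the f-prefix adds nothing and ruff flags it.
url = f"https://api-inference.huggingface.co/framework/text-generation-inference"  # flagged
url = "https://api-inference.huggingface.co/framework/text-generation-inference"   # fixed
print(url)
```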
@@ -4,7 +4,6 @@ import json
 import math
 import os
 import random
-import re
 import shutil
 import subprocess
 import sys
@@ -271,7 +270,7 @@ class LauncherHandle:
             try:
                 await self.client.generate("test")
                 return
-            except (ClientConnectorError, ClientOSError, ServerDisconnectedError) as e:
+            except (ClientConnectorError, ClientOSError, ServerDisconnectedError):
                 time.sleep(1)
         raise RuntimeError("Health check failed")
......
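The health-check hunk drops the `as e` binding because the exception object is never used, which ruff reports as F841 (local variable assigned but never used). A small, self-contained illustration of the same before/after, with a hypothetical stand-in function:

```python
def flaky_ping() -> None:
    # Hypothetical stand-in for a health-check call that can fail.
    raise ConnectionError("server not ready")


# Before: `e` is bound but never used, which ruff reports as F841.
try:
    flaky_ping()
except ConnectionError as e:
    pass

# After: drop the unused binding, mirroring the health-check hunk above.
try:
    flaky_ping()
except ConnectionError:
    pass
```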
 import pytest
-import json
-from text_generation.types import GrammarType

 @pytest.fixture(scope="module")
......
 import pytest
-import requests
-import io
 import base64
......
@@ -74,9 +74,7 @@ async def test_idefics_load(idefics, generate_load, response_snapshot):
     generated_texts = [r.generated_text for r in responses]

-    assert (
-        generated_texts[0] == " \nAssistant: A rooster stands"
-    ), f"{response.generated_text}"
+    assert generated_texts[0] == " \nAssistant: A rooster stands"
     assert len(generated_texts) == 4
     assert generated_texts, all(
         [text == generated_texts[0] for text in generated_texts]
......
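The removed assert message interpolated `response`, a name that is not defined in that test (only the `responses` list is), which ruff reports as F821 (undefined name); collapsing to a single-line assert removes the bad reference. A minimal sketch of the same situation:

```python
responses = ["a", "a", "a", "a"]

# Before (flagged): the assert message interpolates `response`, which does
# not exist in this scope, so ruff reports F821. Shown as a comment so this
# sketch still runs.
# assert (
#     responses[0] == "a"
# ), f"{response}"

# After: a plain single-line assert with no message referencing a bad name.
assert responses[0] == "a"
```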
 import pytest
-import json
-from text_generation.types import GrammarType

 @pytest.fixture(scope="module")
@@ -91,7 +88,7 @@ async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_snapshot):
             },
         ],
     )
-    assert response.choices[0].message.content == None
+    assert response.choices[0].message.content is None
     assert response.choices[0].message.tool_calls == [
         {
             "id": 0,
@@ -129,7 +126,7 @@ async def test_flash_llama_grammar_tools_auto(
             },
         ],
     )
-    assert response.choices[0].message.content == None
+    assert response.choices[0].message.content is None
    assert response.choices[0].message.tool_calls == [
         {
             "id": 0,
@@ -168,7 +165,7 @@ async def test_flash_llama_grammar_tools_choice(
             },
         ],
     )
-    assert response.choices[0].message.content == None
+    assert response.choices[0].message.content is None
     assert response.choices[0].message.tool_calls == [
         {
             "id": 0,
@@ -241,7 +238,7 @@ async def test_flash_llama_grammar_tools_insufficient_information(
         stream=False,
     )
-    assert responses.choices[0].message.content == None
+    assert responses.choices[0].message.content is None
     assert responses.choices[0].message.tool_calls == [
         {
             "function": {
......
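All four tool-call hunks make the same E711 fix: comparisons against the `None` singleton should use identity (`is None`), not equality. Equality can be redefined by a class, so `== None` is not a reliable null check; a short demonstration:

```python
class AlwaysEqual:
    # A class can override == to return anything, so `x == None` is not a
    # reliable null check; identity (`is None`) is, which is what E711 asks for.
    def __eq__(self, other):
        return True


obj = AlwaysEqual()
print(obj == None)  # True, even though obj is clearly not None (flagged as E711)
print(obj is None)  # False, the check the hunks above switch to
```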
@@ -20,7 +20,7 @@ def main():
             break

     with open("./small.json", "w") as f:
-        data = json.dump(conversations, f, indent=4)
+        json.dump(conversations, f, indent=4)

 if __name__ == "__main__":
......
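`json.dump` writes into the open file handle and returns `None`, so binding its result to `data` only creates an unused variable (F841); the call is made purely for its side effect. A self-contained version of the fixed form:

```python
import json

conversations = [{"from": "human", "value": "hello"}]

with open("./small.json", "w") as f:
    # json.dump writes to `f` and returns None, so `data = json.dump(...)`
    # would just bind None and trip F841; calling it is enough.
    json.dump(conversations, f, indent=4)
```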
 import os
-import requests
 import tempfile
 import pytest
 import huggingface_hub.constants
-from huggingface_hub import hf_api
 import text_generation_server.utils.hub
 from text_generation_server.utils.hub import (
......
@@ -2,7 +2,6 @@ import torch
 from text_generation_server.layers import (
     TensorParallelEmbedding,
 )
-from text_generation_server.utils.weights import DefaultWeightsLoader

 class ProcessGroup:
......
@@ -2,7 +2,6 @@ import pytest
 import torch
 from text_generation_server.utils.weights import (
     DefaultWeightsLoader,
-    UnquantizedWeight,
     Weights,
     WeightsLoader,
 )
@@ -86,15 +85,6 @@ dummy_file_system = {
             ],
             dtype=torch.float32,
         ),
-        "weight.weight": torch.tensor(
-            [
-                [1, 2],
-                [3, 4],
-                [5, 6],
-                [7, 8],
-            ],
-            dtype=torch.float32,
-        ),
     },
     "test_get_weights_row": {
         "weight.weight": torch.tensor(
@@ -966,7 +956,7 @@ def test_get_multi_weights_col_exl2():
     prefix = "weight"
     try:
-        w = weights.get_multi_weights_col(
+        weights.get_multi_weights_col(
             prefixes=[prefix],
             dim=0,
         )
......
@@ -4,15 +4,12 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Dict, Set, Tuple
+from typing import Dict, Set, Tuple

 import torch

 from text_generation_server.adapters.weights import AdapterWeights

-if TYPE_CHECKING:
-    from text_generation_server.models.model import Model

 @dataclass
 class ModuleMap:
......
@@ -4,7 +4,7 @@
 from collections import defaultdict
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Type, Union
+from typing import Dict, List, Optional, Set, Tuple, Type, Union

 import torch
 from peft import LoraConfig as _LoraConfig
@@ -26,9 +26,6 @@ from text_generation_server.utils.sgmv import (
     use_cutlass_shrink,
 )

-if TYPE_CHECKING:
-    from text_generation_server.models.model import Model

 def get_start_stop_idxs_for_rank(offset, size, rank, world_size):
     block_size = size // world_size
......
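A `TYPE_CHECKING` guard exists so that imports needed only for type annotations are skipped at runtime, which is also a common way to break import cycles (consistent with the "avoid circular import" item in the commit message). Once no annotation in these adapter modules references `Model` anymore, the guard and the import are dead code and the import is reported as unused, so both are removed. A sketch of the pattern itself, with hypothetical module names:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Hypothetical module, imported only for annotations; this line never
    # executes at runtime, so no circular import can occur through it.
    from mypkg.model import Model


def describe(model: "Model") -> str:
    # The string annotation means the real import is never needed at runtime.
    return f"model: {model!r}"


print(describe(object()))
```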
@@ -4,12 +4,11 @@ import typer
 from pathlib import Path
 from loguru import logger
-from typing import Optional, List, Dict
+from typing import Optional
 from enum import Enum
 from huggingface_hub import hf_hub_download
 from text_generation_server.utils.adapter import parse_lora_adapters
-from text_generation_server.utils.log import log_master

 app = typer.Typer()
@@ -165,7 +164,7 @@ def download_weights(
     # currently by default we don't merge the weights with the base model
     if merge_lora:
         try:
-            adapter_config_filename = hf_hub_download(
+            hf_hub_download(
                 model_id, revision=revision, filename="adapter_config.json"
             )
             utils.download_and_unload_peft(
@@ -285,9 +284,9 @@ def download_weights(
     if auto_convert:
         if not trust_remote_code:
             logger.warning(
-                f"🚨🚨BREAKING CHANGE in 2.0🚨🚨: Safetensors conversion is disabled without `--trust-remote-code` because "
-                f"Pickle files are unsafe and can essentially contain remote code execution!"
-                f"Please check for more information here: https://huggingface.co/docs/text-generation-inference/basic_tutorials/safety",
+                "🚨🚨BREAKING CHANGE in 2.0🚨🚨: Safetensors conversion is disabled without `--trust-remote-code` because "
+                "Pickle files are unsafe and can essentially contain remote code execution!"
+                "Please check for more information here: https://huggingface.co/docs/text-generation-inference/basic_tutorials/safety",
             )
         logger.warning(
@@ -319,7 +318,7 @@ def download_weights(
             # Name for this varible depends on transformers version.
             discard_names = getattr(class_, "_tied_weights_keys", [])
-        except Exception as e:
+        except Exception:
             discard_names = []

         # Convert pytorch weights to safetensors
         utils.convert_files(local_pt_files, local_st_files, discard_names)
......
@@ -18,3 +18,17 @@ from text_generation_server.layers.lora import (
     TensorParallelMultiAdapterLinear,
     TensorParallelAdapterRowLinear,
 )
+
+__all__ = [
+    "get_linear",
+    "FastLinear",
+    "TensorParallelColumnLinear",
+    "TensorParallelRowLinear",
+    "TensorParallelEmbedding",
+    "SpeculativeHead",
+    "LoraLinear",
+    "TensorParallelMultiAdapterLinear",
+    "TensorParallelAdapterRowLinear",
+    "load_layer_norm",
+    "load_conv2d",
+]
@@ -13,3 +13,12 @@ elif SYSTEM == "ipex":
     from .ipex import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
 else:
     raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")
+
+__all__ = [
+    "attention",
+    "paged_attention",
+    "reshape_and_cache",
+    "SUPPORTS_WINDOWING",
+    "Seqlen",
+]
@@ -10,7 +10,6 @@ _PARTITION_SIZE = 512
 try:
     from vllm._C import cache_ops
-    from vllm._C import ops
 except Exception as e:
     raise ImportError(
         f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
......
@@ -747,11 +747,8 @@ class _attention(torch.autograd.Function):
         padded_d_model = 1 << (head_size - 1).bit_length()
         padded_d_model = max(padded_d_model, 16)

-        grid = lambda META: (
-            triton.cdiv(max_seqlens_q, META["BLOCK_M"]),
-            nheads_q,
-            batch,
-        )
+        def grid(META):
+            return triton.cdiv(max_seqlens_q, META["BLOCK_M"]), nheads_q, batch

         encoded_softmax = None
......
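The Triton hunk rewrites a name bound to a lambda as a small `def`, which is what E731 asks for: the `def` form gets a real `__name__` for tracebacks and reads better when several values are returned. A minimal before/after with made-up names:

```python
# E731: do not assign a lambda expression, use a def.
grid = lambda n, block: (n // block, n % block)  # flagged by ruff


def grid(n, block):
    # Equivalent def, returning the same tuple (the rebinding here is only
    # to contrast the two spellings).
    return n // block, n % block


print(grid(10, 3))  # (3, 1)
```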
@@ -15,7 +15,6 @@ ENGINE = "triton" if use_triton else "ck"
 try:
     from vllm._C import cache_ops
-    from vllm._C import ops
 except Exception as e:
     raise ImportError(
         f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
......