Unverified Commit bab02ff2 authored by drbh, committed by GitHub

feat: add ruff and resolve issue (#2262)

* feat: add ruff and resolve issue

* fix: update client exports and adjust after rebase

* fix: adjust syntax to avoid circular import

* fix: adjust client ruff settings

* fix: lint and refactor import check and avoid model enum as global names

* fix: improve fbgemm_gpu check and lints

* fix: update lints

* fix: prefer comparing model enum over str

* fix: adjust lints and ignore specific rules

* fix: avoid unneeded quantize check
parent 4b49c50f
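
One bullet above, "prefer comparing model enum over str", points at a general pattern: match against enum members rather than raw strings. A minimal sketch with made-up names (the project's real model enum is not shown in this diff):

```python
from enum import Enum


class ModelType(Enum):
    # Illustrative members only; the project defines its own model enum.
    LLAMA = "llama"
    MAMBA = "mamba"


def is_mamba(model_type: ModelType) -> bool:
    # Comparing members (by identity) instead of strings means a typo such
    # as "mambo" fails loudly at attribute lookup rather than silently
    # evaluating to False, and type checkers can verify the comparison.
    return model_type is ModelType.MAMBA


assert is_mamba(ModelType.MAMBA)
assert not is_mamba(ModelType.LLAMA)
```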
@@ -2,23 +2,15 @@ import re
 import torch
 import torch.distributed
-from typing import List, Optional, Type
 from transformers import (
-    AutoTokenizer,
-    AutoConfig,
     PreTrainedTokenizerBase,
 )
-from text_generation_server.models import CausalLM
 from text_generation_server.models.causal_lm import CausalLMBatch
 from text_generation_server.pb import generate_pb2
-from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM
 from text_generation_server.utils import (
     NextTokenChooser,
     StoppingCriteria,
-    initialize_torch_distributed,
-    weight_files,
-    Weights,
 )
 from text_generation_server.utils.chunks import concat_text_chunks
...
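
Most of the hunks in this commit are the same mechanical change as the one above: imports flagged by ruff's F401 rule (imported but unused) are deleted. A small illustrative sketch, not taken from the repository; the `# noqa: F401` marker is the standard escape hatch for an import kept purely for re-export (compare the "update client exports" bullet in the commit message):

```python
# Only the names this module actually uses are imported; `ruff check --fix`
# removes the rest automatically.
from typing import Optional


def first_or_none(items: list[str]) -> Optional[str]:
    # A tiny function so the remaining import is genuinely used.
    return items[0] if items else None


# In a package __init__.py that deliberately re-exports a symbol,
# the import can be preserved with an explicit marker instead:
# from .client import Client  # noqa: F401
```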
 import torch
 import torch.distributed
-from typing import List, Optional, Tuple
-from transformers import (
-    AutoTokenizer,
-    AutoConfig,
-    AutoProcessor,
-)
+from typing import Optional
 from text_generation_server.models.custom_modeling.idefics_config import IdeficsConfig
 from text_generation_server.models.custom_modeling.idefics_processing import (
...
@@ -289,7 +289,7 @@ class IdeficsCausalLMBatch(Batch):
         image_hidden_states = self.image_hidden_states[keep_indices]
         # Ensure that past_key_values tensors can be updated in-place
-        if type(self.past_key_values[0]) == tuple:
+        if type(self.past_key_values[0]) is tuple:
             self.past_key_values = [list(layer) for layer in self.past_key_values]
         # Update tensors in-place to allow incremental garbage collection
@@ -456,7 +456,7 @@ class IdeficsCausalLMBatch(Batch):
         # BLOOM Keys: [batch_size * num_heads, head_dim, seq_length]
         # BLOOM Values: [batch_size * num_heads, seq_length, head_dim]
         # And ensure that we can update tensors in-place
-        if type(batch.past_key_values[0]) == tuple:
+        if isinstance(batch.past_key_values[0], tuple):
             batch.past_key_values = [
                 [t.view(len(batch), -1, *t.shape[-2:]) for t in layer]
                 for layer in batch.past_key_values
...
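
The two hunks above touch direct type comparisons, which ruff flags as E721 when written with `==`. A quick standalone sketch of the two replacement styles used here:

```python
# E721: comparing types with `==` is flagged; the fix depends on intent.
layer = (1, 2)

# Exact-type check: `is` compares the two type objects by identity and
# matches tuple itself but not tuple subclasses.
if type(layer) is tuple:
    layer = list(layer)

# isinstance() also accepts subclasses, which is usually the intended
# behaviour when the code only needs tuple-like access.
if isinstance(layer, (tuple, list)):
    first = layer[0]
```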
@@ -2,7 +2,6 @@ import torch
 import torch.distributed
 from transformers import AutoTokenizer, PreTrainedTokenizerBase
 from typing import Optional
-import os
 from text_generation_server.models.custom_modeling.mamba_modeling import (
     MambaConfig,
 )
@@ -20,7 +19,7 @@ from text_generation_server.models.custom_modeling.mamba_modeling import (
     InferenceParams,
 )
 from text_generation_server.models import Model
-from typing import Any, List, Optional, Tuple, Type, Dict
+from typing import Any, List, Tuple, Type, Dict
 from text_generation_server.models.types import (
     Batch,
     Tokens,
@@ -31,7 +30,7 @@ from text_generation_server.utils.chunks import concat_text_chunks
 from text_generation_server.utils.quantization import get_loader
 from text_generation_server.utils.tokens import batch_top_tokens, Sampling
 from dataclasses import dataclass
-from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
+from text_generation_server.utils import NextTokenChooser, StoppingCriteria
 def new_inference_params(
@@ -299,7 +298,6 @@ class MambaBatch(Batch):
         stopping_criterias = []
         top_n_tokens = []
         max_tokens = 0
-        max_seqlen = 0
         seqlen_offset = 0
         (n_blocks, _, d_inner, d_conv) = batches[0].inference_params.conv_states.shape
@@ -485,7 +483,7 @@ class Mamba(Model):
                 for bs in CUDA_GRAPHS:
                     self.cuda_graph_warmup(bs)
             except Exception:
-                logger.exception(f"Decode cuda graph warmup failed")
+                logger.exception("Decode cuda graph warmup failed")
         else:
             logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).")
@@ -534,7 +532,7 @@ class Mamba(Model):
         }
         self.cuda_graphs[batch_size] = graph_dict
-    def tunableop_warmup(self, seqlen: int):
+    def tunableop_warmup(self, batch_size: int, seqlen: int):
         input_ids = torch.zeros((batch_size, 1), dtype=torch.int64, device=self.device)
         n_blocks = len(self.model.blocks)
...
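
Two smaller fixes are visible in the Mamba hunks: the warmup failure message loses its `f` prefix because it interpolates nothing (ruff F541, f-string without placeholders), and `tunableop_warmup` gains a `batch_size` parameter, a name the unchanged body line below the signature already uses. A minimal sketch of the logging side, assuming loguru as in the file above and an illustrative `CUDA_GRAPHS` value:

```python
from loguru import logger

CUDA_GRAPHS: list = []  # illustrative value, not the project's default

if CUDA_GRAPHS:
    try:
        pass  # cuda-graph warmup work would run here
    except Exception:
        # Nothing is interpolated, so a plain string is used; an f-prefix
        # here is exactly what F541 flags.
        logger.exception("Decode cuda graph warmup failed")
else:
    # A value is interpolated, so the f-string stays.
    logger.info(f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS}).")
```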
@@ -2,7 +2,7 @@ import inspect
 import torch
 from abc import ABC, abstractmethod
-from typing import List, Tuple, Optional, TypeVar, Type, Dict, DefaultDict
+from typing import List, Tuple, Optional, TypeVar, Type, Dict
 from collections import defaultdict
 from transformers import PreTrainedTokenizerBase
...
@@ -3,16 +3,11 @@ from PIL import Image
 import torch
 import torch.distributed
 from opentelemetry import trace
-from typing import Iterable, Optional, Tuple
+from typing import Iterable
 from text_generation_server.models.vlm_causal_lm import (
-    VlmCausalLM,
     VlmCausalLMBatch,
     image_text_replacement,
 )
-from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
-    PaliGemmaForConditionalGeneration,
-)
-from transformers import AutoProcessor, AutoConfig
 from text_generation_server.pb.generate_pb2 import Request
...
 import torch
 import torch.distributed
-import time
 from dataclasses import dataclass
 from opentelemetry import trace
 from transformers import (
@@ -11,7 +10,7 @@ from transformers import (
     AutoConfig,
 )
 from typing import Optional, Tuple, List, Type, Dict
-from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils import (
     initialize_torch_distributed,
     weight_files,
@@ -254,7 +253,7 @@ class Seq2SeqLMBatch(Batch):
         ]
         # Ensure that past_key_values tensors can be updated in-place
-        if type(self.past_key_values[0]) == tuple:
+        if type(self.past_key_values[0]) is tuple:
             self.past_key_values = [
                 [t for t in layer] for layer in self.past_key_values
             ]
@@ -430,7 +429,7 @@ class Seq2SeqLMBatch(Batch):
         batch.encoder_last_hidden_state = None
         # Ensure that we can update tensors in-place
-        if type(batch.past_key_values[0]) == tuple:
+        if isinstance(batch.past_key_values[0], tuple):
             batch.past_key_values = [
                 [t for t in layer] for layer in batch.past_key_values
             ]
...
-from functools import total_ordering
 import torch
 from abc import ABC, abstractmethod
...
@@ -9,7 +9,7 @@ from loguru import logger
 from grpc_reflection.v1alpha import reflection
 from pathlib import Path
-from typing import List, Optional, Dict
+from typing import List, Optional
 from text_generation_server.cache import Cache
 from text_generation_server.interceptor import ExceptionInterceptor
...
 import torch
 from loguru import logger
-import subprocess
 import os
+import importlib.util
 def is_ipex_available():
-    try:
-        import intel_extension_for_pytorch
-    except ImportError:
-        return False
-    return True
+    return importlib.util.find_spec("intel_extension_for_pytorch") is not None
 def get_cuda_free_memory(device, memory_fraction):
...
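
The rewritten `is_ipex_available` above checks for the package without importing it. A hedged sketch of the same `importlib.util.find_spec` pattern applied to an arbitrary optional dependency (the helper name is made up):

```python
import importlib.util


def is_package_available(name: str) -> bool:
    # find_spec only consults the import system's metadata, so the probed
    # package is never executed; a bare `import` inside try/except would run
    # its top-level code (and possibly heavy initialisation) just to test
    # for its presence.
    return importlib.util.find_spec(name) is not None


if is_package_available("intel_extension_for_pytorch"):
    pass  # the IPEX-specific code path would go here
```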
@@ -2,9 +2,17 @@ import copy
 from abc import ABC
 from collections import defaultdict
 from typing import TYPE_CHECKING, Dict, List, Tuple, Type, Union
+from text_generation_server.utils.merges.utils import (
+    calculate_majority_sign_mask,
+    disjoint_merge,
+    prune,
+)
 import torch
+if TYPE_CHECKING:
+    from text_generation_server.adapters.lora import LoraConfig
+    from text_generation_server.utils.adapter import ModuleMap
 class AdapterParameters:
     def __init__(
@@ -17,17 +25,6 @@ class AdapterParameters:
         self.majority_sign_method = majority_sign_method
-from text_generation_server.utils.merges.utils import (
-    calculate_majority_sign_mask,
-    disjoint_merge,
-    prune,
-)
-if TYPE_CHECKING:
-    from text_generation_server.adapters.lora import LoraConfig
-    from text_generation_server.utils.adapter import ModuleMap
 def _apply_weights(
     tensors: Union[torch.Tensor, List[torch.Tensor]], w: torch.Tensor
 ) -> torch.Tensor:
...
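
The hunks above move imports that previously sat below the first class definition up to the top of the module (ruff E402, "module level import not at top of file"), keeping the type-only ones behind `TYPE_CHECKING` so no runtime, and potentially circular, import is introduced. A small sketch of that layout with hypothetical module paths:

```python
from typing import TYPE_CHECKING

import torch  # runtime dependency, imported unconditionally at the top

if TYPE_CHECKING:
    # Only evaluated by static type checkers, never at runtime, so it cannot
    # create an import cycle; the name is used as a string annotation below.
    from mypackage.lora import LoraConfig  # hypothetical path


class AdapterParameters:
    def merge_weight(self, config: "LoraConfig") -> torch.Tensor:
        # Placeholder body; the real merge logic lives in the project.
        return torch.zeros(1)
```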
@@ -28,7 +28,7 @@ def download_and_unload_peft(model_id, revision, trust_remote_code):
         low_cpu_mem_usage=True,
     )
     logger.info("Peft model detected.")
-    logger.info(f"Merging the lora weights.")
+    logger.info("Merging the lora weights.")
     base_model_id = model.peft_config["default"].base_model_name_or_path
...
@@ -6,7 +6,6 @@ from typing import Optional
 from huggingface_hub import hf_hub_download
 from text_generation_server.utils.weights import (
     DefaultWeightsLoader,
-    UnquantizedWeight,
     WeightsLoader,
 )
...
 import re
 from typing import List, Optional, Tuple, Set, Union
-import math
 import torch
 from text_generation_server.pb import generate_pb2
 from text_generation_server.pb.generate_pb2 import FinishReason, GrammarType
...