Unverified commit bab02ff2, authored by drbh and committed by GitHub

feat: add ruff and resolve issue (#2262)

* feat: add ruff and resolve issue

* fix: update client exports and adjust after rebase

* fix: adjust syntax to avoid circular import

* fix: adjust client ruff settings

* fix: lint and refactor import check and avoid model enum as global names

* fix: improve fbgemm_gpu check and lints

* fix: update lints

* fix: prefer comparing model enum over str

* fix: adjust lints and ignore specific rules

* fix: avoid unneeded quantize check
parent 4b49c50f
# Copied logic from https://github.com/mit-han-lab/llm-awq/blob/f084f40bd996f3cf3a0633c1ad7d9d476c318aaa/awq/quantize/qmodule.py
import math
from typing import Optional
import torch
import torch.nn as nn
......
from dataclasses import dataclass
from functools import lru_cache
import bitsandbytes as bnb
import torch
......
......
@@ -12,17 +12,26 @@ from text_generation_server.utils.weights import (
     Weights,
 )
 from text_generation_server.utils.log import log_master, log_once
+import importlib.util

 FBGEMM_MM_AVAILABLE = False
 FBGEMM_DYN_AVAILABLE = False

-try:
-    import fbgemm_gpu.experimental.gen_ai
+
+def is_fbgemm_gpu_available():
+    try:
+        return importlib.util.find_spec("fbgemm_gpu.experimental.gen_ai") is not None
+    except ModuleNotFoundError:
+        return False
+
+
+if is_fbgemm_gpu_available():
     if SYSTEM == "cuda":
         major, _ = torch.cuda.get_device_capability()
         FBGEMM_MM_AVAILABLE = major == 9
         FBGEMM_DYN_AVAILABLE = major >= 8
-except (ImportError, ModuleNotFoundError):
+else:
     log_master(logger.warning, "FBGEMM fp8 kernels are not installed.")
......
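The new check replaces import-and-catch with a spec lookup, so probing for the optional dependency no longer imports the CUDA extension as a side effect. A minimal, self-contained sketch of the same pattern (the probed names below are illustrative):

```python
import importlib.util


def is_available(module_name: str) -> bool:
    """Return True if `module_name` is importable, without importing it."""
    try:
        return importlib.util.find_spec(module_name) is not None
    except ModuleNotFoundError:
        # find_spec raises when a parent package of a dotted name is missing,
        # e.g. probing "fbgemm_gpu.experimental.gen_ai" without fbgemm_gpu.
        return False


print(is_available("json"))                # True
print(is_available("no_such_pkg.gen_ai"))  # False
```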
......
@@ -8,6 +8,34 @@ from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.log import log_once
 from text_generation_server.utils.weights import Weight, Weights, WeightsLoader

+try:
+    major, _minor = torch.cuda.get_device_capability()
+except Exception:
+    major = 1
+
+HAS_EXLLAMA = False
+CAN_EXLLAMA = major >= 8 or SYSTEM == "rocm"
+V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"
+if os.getenv("DISABLE_EXLLAMA") == "True":
+    HAS_EXLLAMA = False
+elif CAN_EXLLAMA:
+    try:
+        if V2:
+            from text_generation_server.layers.gptq.exllamav2 import (
+                QuantLinear as ExllamaQuantLinear,  # noqa: F401
+            )
+
+            HAS_EXLLAMA = "2"
+        else:
+            from text_generation_server.layers.gptq.exllama import (
+                Ex4bitLinear as ExllamaQuantLinear,  # noqa: F401
+            )
+
+            HAS_EXLLAMA = "1"
+    except ImportError:
+        pass
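The `# noqa: F401` markers are load-bearing: the names are imported here only so that other modules can import them from `text_generation_server.layers.gptq`, and without the marker ruff's unused-import rule (F401) would flag or auto-remove them. A tiny runnable sketch of the re-export pattern, using a standard-library name for illustration:

```python
# In a package's __init__.py, keep a name purely for re-export.
# Without `noqa: F401`, `ruff --fix` would delete the import as unused.
from math import sqrt as sqrt  # noqa: F401

# Listing the name in __all__ documents the same intent to readers and linters.
__all__ = ["sqrt"]
```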
@dataclass
class GPTQWeight(Weight):
......
@@ -55,7 +83,7 @@ class GPTQWeight(Weight):
            from text_generation_server.layers.gptq import ExllamaQuantLinear
        except ImportError:
            raise NotImplementedError(
-                f"Exllama gptq kernels are not installed. Install them `cd server/exllama_kernels && python setup.py install && cd ../exllamav2_kernels && python setup.py install`"
+                "Exllama gptq kernels are not installed. Install them `cd server/exllama_kernels && python setup.py install && cd ../exllamav2_kernels && python setup.py install`"
            )

        return ExllamaQuantLinear(self, bias)
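Dropping the `f` prefix fixes ruff's F541 (f-string without placeholders): the literal contains no `{...}` fields, so the prefix was inert and only suggested interpolation that never happens. For illustration:

```python
kernel = "exllama"
bad = f"kernels are not installed"            # F541: no placeholder, `f` is inert
good = f"{kernel} kernels are not installed"  # placeholder present, prefix justified
print(bad, "|", good)
```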
......
@@ -73,45 +101,6 @@ class GPTQWeight(Weight):
        )

-try:
-    major, _minor = torch.cuda.get_device_capability()
-except Exception:
-    major = 1
-
-HAS_EXLLAMA = False
-CAN_EXLLAMA = major >= 8 or SYSTEM == "rocm"
-V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"
-if os.getenv("DISABLE_EXLLAMA") == "True":
-    HAS_EXLLAMA = False
-elif CAN_EXLLAMA:
-    try:
-        if V2:
-            from text_generation_server.layers.gptq.exllamav2 import (
-                QuantLinear as ExllamaQuantLinear,
-            )
-            from text_generation_server.layers.gptq.exllamav2 import (
-                create_exllama_buffers,
-                set_device,
-            )
-
-            HAS_EXLLAMA = "2"
-        else:
-            from text_generation_server.layers.gptq.exllama import (
-                Ex4bitLinear as ExllamaQuantLinear,
-            )
-            from text_generation_server.layers.gptq.exllama import (
-                create_exllama_buffers,
-                set_device,
-            )
-
-            HAS_EXLLAMA = "1"
-    except ImportError:
-        pass
-
from text_generation_server.layers.gptq.quant_linear import QuantLinear
from text_generation_server.layers.gptq.quant_linear import QuantLinear
class GPTQWeightsLoader(WeightsLoader):
"""
Loader for GPTQ- and AWQ-quantized weights.
......
......
@@ -206,10 +206,13 @@ def matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):
    output = torch.empty(
        (input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16
    )
-    grid = lambda META: (
-        triton.cdiv(input.shape[0], META["BLOCK_SIZE_M"])
-        * triton.cdiv(qweight.shape[1], META["BLOCK_SIZE_N"]),
-    )
+
+    def grid(META):
+        return (
+            triton.cdiv(input.shape[0], META["BLOCK_SIZE_M"])
+            * triton.cdiv(qweight.shape[1], META["BLOCK_SIZE_N"]),
+        )
+
    matmul_248_kernel[grid](
        input,
        qweight,
......
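This rewrite resolves ruff's E731 (do not assign a `lambda` to a name): the equivalent `def` carries a real `__name__` for tracebacks and profilers. Triton later calls `grid` with the chosen meta-parameters to size the launch grid. A standalone sketch of the rewrite, with the Triton dependency replaced by plain arithmetic:

```python
import math

# Before (flagged by E731):
#     grid = lambda meta: (math.ceil(1024 / meta["BLOCK_SIZE_M"]),)


# After: the same callable, but named and introspectable.
def grid(meta):
    return (math.ceil(1024 / meta["BLOCK_SIZE_M"]),)


print(grid({"BLOCK_SIZE_M": 64}))  # (16,)
```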
......
@@ -15,6 +15,7 @@ from text_generation_server.utils.hub import weight_files
from text_generation_server.layers.gptq.quant_linear import QuantLinear
from loguru import logger
from typing import Optional
+from text_generation_server.layers.gptq.utils import torch_snr_error
from text_generation_server.utils.weights import DefaultWeightsLoader
......
@@ -372,7 +373,7 @@ def get_wikitext2(nsamples, seed, seqlen, model_id, trust_remote_code):
        tokenizer = AutoTokenizer.from_pretrained(
            model_id, use_fast=False, trust_remote_code=trust_remote_code
        )
-    except:
+    except Exception:
        tokenizer = AutoTokenizer.from_pretrained(
            model_id, use_fast=True, trust_remote_code=trust_remote_code
        )
......
@@ -404,7 +405,7 @@ def get_ptb(nsamples, seed, seqlen, model_id, trust_remote_code):
        tokenizer = AutoTokenizer.from_pretrained(
            model_id, use_fast=False, trust_remote_code=trust_remote_code
        )
-    except:
+    except Exception:
        tokenizer = AutoTokenizer.from_pretrained(
            model_id, use_fast=True, trust_remote_code=trust_remote_code
        )
......
@@ -448,7 +449,7 @@ def get_c4(nsamples, seed, seqlen, model_id, trust_remote_code):
        tokenizer = AutoTokenizer.from_pretrained(
            model_id, use_fast=False, trust_remote_code=trust_remote_code
        )
-    except:
+    except Exception:
        tokenizer = AutoTokenizer.from_pretrained(
            model_id, use_fast=True, trust_remote_code=trust_remote_code
        )
......
@@ -504,7 +505,7 @@ def get_ptb_new(nsamples, seed, seqlen, model_id, trust_remote_code):
        tokenizer = AutoTokenizer.from_pretrained(
            model_id, use_fast=False, trust_remote_code=trust_remote_code
        )
-    except:
+    except Exception:
        tokenizer = AutoTokenizer.from_pretrained(
            model_id, use_fast=True, trust_remote_code=trust_remote_code
        )
......
@@ -546,7 +547,7 @@ def get_c4_new(nsamples, seed, seqlen, model_id, trust_remote_code):
        tokenizer = AutoTokenizer.from_pretrained(
            model_id, use_fast=False, trust_remote_code=trust_remote_code
        )
-    except:
+    except Exception:
        tokenizer = AutoTokenizer.from_pretrained(
            model_id, use_fast=True, trust_remote_code=trust_remote_code
        )
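All five hunks in this file fix ruff's E722 (bare `except`). A bare `except:` catches `BaseException`, so it also swallows `KeyboardInterrupt` and `SystemExit`; `except Exception:` keeps the slow-tokenizer fallback while letting those propagate. A minimal illustration of the difference:

```python
def parse_or_default(text: str, default: int = 0) -> int:
    try:
        return int(text)
    except Exception:  # a bare `except:` would also trap KeyboardInterrupt/SystemExit
        return default


print(parse_or_default("42"))   # 42
print(parse_or_default("n/a"))  # 0
```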
......
@@ -700,6 +701,8 @@ def sequential(
        pass

    def add_batch(name):
+        nonlocal gptq
+
        def tmp(_, inp, out):
            gptq[name].add_batch(inp[0].data, out.data)
......
import torch


# copied from https://github.com/openppl-public/ppq/blob/master/ppq/quantization/measure/norm.py
def torch_snr_error(
    y_pred: torch.Tensor, y_real: torch.Tensor, reduction: str = "mean"
) -> torch.Tensor:
    """
    Compute the SNR error between y_pred (tensor) and y_real (tensor).

    SNR error is calculated as:
        SNR(pred, real) = (pred - real) ^ 2 / (real) ^ 2
    If x and y are matrices, the SNR error over the matrix is the mean of the
    SNR error over all elements:
        SNR(pred, real) = mean((pred - real) ^ 2 / (real) ^ 2)

    Args:
        y_pred (torch.Tensor): predicted tensor
        y_real (torch.Tensor): reference tensor, must have the same shape as y_pred
        reduction (str, optional): 'mean', 'sum' or 'none'. Defaults to 'mean'.

    Raises:
        ValueError: if the two tensors have different shapes
        ValueError: if the reduction method is unsupported

    Returns:
        torch.Tensor: the SNR error
    """
    if y_pred.shape != y_real.shape:
        raise ValueError(
            f"Can not compute snr loss for tensors with different shape. "
            f"({y_pred.shape} and {y_real.shape})"
        )
    reduction = str(reduction).lower()

    if y_pred.ndim == 1:
        y_pred = y_pred.unsqueeze(0)
        y_real = y_real.unsqueeze(0)

    y_pred = y_pred.flatten(start_dim=1)
    y_real = y_real.flatten(start_dim=1)

    noise_power = torch.pow(y_pred - y_real, 2).sum(dim=-1)
    signal_power = torch.pow(y_real, 2).sum(dim=-1)
    snr = (noise_power) / (signal_power + 1e-7)

    if reduction == "mean":
        return torch.mean(snr)
    elif reduction == "sum":
        return torch.sum(snr)
    elif reduction == "none":
        return snr
    else:
        raise ValueError("Unsupported reduction method.")
from typing import Optional
import torch
from text_generation_server.utils.import_utils import SYSTEM
from torch.nn import functional as F
......
import math
import os
-from typing import TYPE_CHECKING, Optional, Tuple, List
+from typing import TYPE_CHECKING, Optional, List
import torch
import torch.distributed
from accelerate import init_empty_weights
from torch import nn
from torch.nn import functional as F
from torch.distributed import ProcessGroup
from text_generation_server.utils.sgmv import (
......
from typing import List, Tuple
import torch
from text_generation_server.layers.marlin.fp8 import GPTQMarlinFP8Linear
from text_generation_server.layers.marlin.gptq import (
GPTQMarlinLinear,
......
from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Union
import torch
import torch.nn as nn
......
@@ -85,7 +85,7 @@ class MarlinWeightsLoader(WeightsLoader):
            )
        except RuntimeError:
            raise RuntimeError(
-                f"Cannot load `marlin` weight, make sure the model is already quantized"
+                "Cannot load `marlin` weight, make sure the model is already quantized"
            )
B_meta = torch.cat(
......
@@ -104,7 +104,7 @@ class MarlinWeightsLoader(WeightsLoader):
            )
        except RuntimeError:
            raise RuntimeError(
-                f"Cannot load `marlin` weight, make sure the model is already quantized"
+                "Cannot load `marlin` weight, make sure the model is already quantized"
            )
s = torch.cat(
[weights.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1
......
......
@@ -2,12 +2,9 @@ import os
import math
import torch
from torch import nn
-from loguru import logger
from text_generation_server.utils.import_utils import SYSTEM
if SYSTEM == "cuda":
from flash_attn.layers.rotary import RotaryEmbedding
import rotary_emb
elif SYSTEM == "rocm":
from vllm._C import ops
......
......
@@ -33,7 +33,7 @@ class SpeculativeHead(torch.nn.Module):
        except KeyError:
            try:
                speculator = MedusaHeadV1.load(config, prefix, weights)
-            except:
+            except Exception:
                speculator = MedusaHeadV2(config, prefix, weights)
            lm_head = None
        else:
......
......
@@ -2,7 +2,6 @@ import torch
from torch.nn import functional as F
from typing import Iterable, List
from text_generation_server.layers.linear import get_linear, FastLinear
-from text_generation_server.layers.exl2 import Exl2Weight
from text_generation_server.utils.import_utils import SYSTEM
if SYSTEM == "ipex":
......
@@ -50,7 +49,7 @@ class TensorParallelHead(SuperLayer):
            # If the piece and LM head embeddings are shared, we have
            # non-quantized weights...
            weight = weights.get_tensor(f"{prefix}.weight")
-        except:
+        except Exception:
            # ...otherwise they are quantized.
            weight = weights.get_weights_col(prefix)
        should_gather = weights.process_group.size() > 1
......
@@ -67,15 +66,6 @@ class TensorParallelHead(SuperLayer):
            weight = weights.get_tensor(f"{prefix}.weight")
            should_gather = False

-        # GPTQ,AWQ,EETQ don't quantize heads (nor embeddings)
-        if config.quantize in ["gptq", "awq", "eetq", "marlin"]:
-            quantize = None
-        # See above, exl2 LM head can be quantized or not.
-        elif config.quantize == "exl2" and not isinstance(weight, Exl2Weight):
-            quantize = None
-        else:
-            quantize = config.quantize
        return TensorParallelHead(
            get_linear(weight, bias=None),
            process_group=weights.process_group,
......
# ruff: noqa: F821
# the above line disables the `undefined-name` rule for the model type variables
import torch
import enum
import os
......
@@ -712,6 +715,7 @@ def get_model(
        )
    elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3:
+        print(f">>> model_type: {model_type}")
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
......
@@ -856,7 +860,7 @@ def get_model(
                lora_adapter_ids=lora_adapter_ids,
                config_class=RWConfig,
            )
-        raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Falcon"))
+        raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Falcon"))
    else:
        if FLASH_ATTENTION and not config_dict.get("alibi", False):
            return FlashCausalLM(
......
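Per the commit notes ("prefer comparing model enum over str"), `LLAMA`, `BAICHUAN`, and `PHI3` above are enum members rather than string literals, so a misspelled name fails loudly instead of comparing quietly false. A simplified sketch of the idea (this enum is illustrative, not TGI's actual `ModelType`):

```python
import enum


class ModelType(enum.Enum):
    LLAMA = "llama"
    BAICHUAN = "baichuan"
    PHI3 = "phi3"


model_type = ModelType.LLAMA
if model_type == ModelType.LLAMA or model_type == ModelType.PHI3:
    print("flash-attention path")

# A typo such as `ModelType.LLMA` raises AttributeError immediately,
# whereas `model_type == "llma"` would just evaluate to False.
```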
......
@@ -233,7 +233,7 @@ class CausalLMBatch(Batch):
        ]

        # Ensure that past_key_values tensors can be updated in-place
-        if type(self.past_key_values[0]) == tuple:
+        if type(self.past_key_values[0]) is tuple:
            self.past_key_values = [list(layer) for layer in self.past_key_values]

        # Update tensors in-place to allow incremental garbage collection
......
@@ -377,7 +377,7 @@ class CausalLMBatch(Batch):
                # BLOOM Keys:   [batch_size * num_heads, head_dim, seq_length]
                # BLOOM Values: [batch_size * num_heads, seq_length, head_dim]
                # And ensure that we can update tensors in-place
-                if type(batch.past_key_values[0]) == tuple:
+                if isinstance(batch.past_key_values[0], tuple):
                    batch.past_key_values = [
                        [t.view(len(batch), -1, *t.shape[-2:]) for t in layer]
                        for layer in batch.past_key_values
......
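Both hunks address ruff's E721 (type comparison with `==`). `type(x) is tuple` is an exact-type check, while `isinstance(x, tuple)` also accepts subclasses, so the two fixes are not interchangeable in general. A small illustration:

```python
class NamedPair(tuple):
    pass


x = NamedPair((1, 2))
print(type(x) == tuple)      # False, and E721 flags the `==` spelling
print(type(x) is tuple)      # False: exact type only, subclasses excluded
print(isinstance(x, tuple))  # True: subclasses count
```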
......
@@ -908,7 +908,7 @@ class BloomForCausalLM(BloomPreTrainedModel):
        loss = None

        if not return_dict:
-            output = (lm_logits,) + transformer_outputs[1:]
+            output = (logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output
return (
......
-from typing import Optional, Tuple, Union
+from typing import Optional, Tuple
import torch
from torch import nn
......
@@ -9,9 +9,7 @@ from transformers.modeling_attn_mask_utils import (
    _prepare_4d_attention_mask,
)
from transformers.modeling_outputs import (
-    BaseModelOutput,
    BaseModelOutputWithPooling,
-    ImageClassifierOutput,
)
from transformers import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
......
@@ -446,11 +444,12 @@ class CLIPEncoder(nn.Module):
class CLIPTextTransformer(nn.Module):
-    def __init__(self, prefix: str, config: CLIPTextConfig):
+    def __init__(self, prefix: str, config: CLIPTextConfig, weights=None):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size
        self.embeddings = CLIPTextEmbeddings(config)
        # Initialize weights and apply final processing with `self.post_init()`
        self.encoder = CLIPEncoder(
            prefix=f"{prefix}.encoder", config=config, weights=weights
        )
......
@@ -505,7 +504,7 @@ class CLIPTextTransformer(nn.Module):
            # text_embeds.shape = [batch_size, sequence_length, transformer.width]
            # take features from the eot embedding (eot_token is the highest number in each sequence)
            # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
-            pooled_output = last_hidden_state[
+            last_hidden_state[
                torch.arange(
                    last_hidden_state.shape[0], device=last_hidden_state.device
                ),
......
@@ -515,7 +514,7 @@
            ]
        else:
            # The config gets updated `eos_token_id` from PR #24773 (so the use of extra new tokens is possible)
-            pooled_output = last_hidden_state[
+            last_hidden_state[
                torch.arange(
                    last_hidden_state.shape[0], device=last_hidden_state.device
                ),
......
@@ -565,9 +564,6 @@ class CLIPTextModel(CLIPPreTrainedModel):
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output  # pooled (EOS token) states
        ```"""
-        return_dict = (
-            return_dict if return_dict is not None else self.config.use_return_dict
-        )
        return self.text_model(
            input_ids=input_ids,
......
@@ -580,7 +576,6 @@ class CLIPVisionTransformer(nn.Module):
    def __init__(self, prefix, config: CLIPVisionConfig, weights):
        super().__init__()
        self.config = config
-        embed_dim = config.hidden_size
        self.embeddings = CLIPVisionEmbeddings(
            prefix=f"{prefix}.embeddings", config=config, weights=weights
......
@@ -661,9 +656,6 @@ class CLIPVisionModel(CLIPPreTrainedModel):
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output  # pooled CLS states
        ```"""
-        return_dict = (
-            return_dict if return_dict is not None else self.config.use_return_dict
-        )
        return self.vision_model(
            pixel_values=pixel_values,
......
@@ -799,14 +791,12 @@ class CLIPModel(nn.Module):
        # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
-            return_dict=return_dict,
        )

        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
-            return_dict=return_dict,
        )

        image_embeds = vision_outputs[1]
......
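The CLIP hunks above are mostly ruff F841 (local variable assigned but never used) cleanups: `pooled_output`, `embed_dim`, and the `return_dict` locals were computed and never read. Note how the fix keeps the indexing expression while dropping only the assignment, which is why a bare `last_hidden_state[...]` subscript survives. A tiny sketch of the rule:

```python
def variance(values):
    n = len(values)
    mean = sum(values) / n
    total = sum((v - mean) ** 2 for v in values)
    spread = max(values) - min(values)  # F841: assigned but never read; the fix deletes this line
    return total / n


print(variance([1.0, 2.0, 3.0]))  # ~0.6667
```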
......
@@ -30,7 +30,6 @@ from text_generation_server.layers.attention import (
    attention,
    reshape_and_cache,
)
-from text_generation_server.models.globals import FLASH_DECODING
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers import (
    TensorParallelRowLinear,
......
......
@@ -44,7 +44,6 @@ from text_generation_server.layers.rotary import (
from text_generation_server.layers.layernorm import (
    FastLayerNorm,
)
-from text_generation_server.utils.log import log_once

class DbrxAttentionConfig(PretrainedConfig):
......