Unverified commit bab02ff2 authored by drbh, committed by GitHub

feat: add ruff and resolve issue (#2262)

* feat: add ruff and resolve issue

* fix: update client exports and adjust after rebase

* fix: adjust syntax to avoid circular import

* fix: adjust client ruff settings

* fix: lint and refactor import check and avoid model enum as global names

* fix: improve fbgemm_gpu check and lints

* fix: update lints

* fix: prefer comparing model enum over str

* fix: adjust lints and ignore specific rules

* fix: avoid unneeded quantize check
parent 4b49c50f
# Copied logic from https://github.com/mit-han-lab/llm-awq/blob/f084f40bd996f3cf3a0633c1ad7d9d476c318aaa/awq/quantize/qmodule.py
import math
from typing import Optional
import torch
import torch.nn as nn
......
from dataclasses import dataclass
from functools import lru_cache
import bitsandbytes as bnb
import torch
......
@@ -12,17 +12,26 @@ from text_generation_server.utils.weights import (
    Weights,
)
from text_generation_server.utils.log import log_master, log_once
import importlib.util
FBGEMM_MM_AVAILABLE = False
FBGEMM_DYN_AVAILABLE = False

-try:
-    import fbgemm_gpu.experimental.gen_ai
+def is_fbgemm_gpu_available():
+    try:
+        return importlib.util.find_spec("fbgemm_gpu.experimental.gen_ai") is not None
+    except ModuleNotFoundError:
+        return False
+
+if is_fbgemm_gpu_available():
    if SYSTEM == "cuda":
        major, _ = torch.cuda.get_device_capability()
        FBGEMM_MM_AVAILABLE = major == 9
        FBGEMM_DYN_AVAILABLE = major >= 8
-except (ImportError, ModuleNotFoundError):
+else:
    log_master(logger.warning, "FBGEMM fp8 kernels are not installed.")
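The hunk above swaps an import-and-catch probe for `importlib.util.find_spec`, which reports whether a module can be located without actually importing it, so the heavy kernel package is only loaded where it is really used. A minimal, self-contained sketch of that pattern (the package name here is hypothetical, not from this PR):

```python
import importlib.util


def is_package_available(name: str) -> bool:
    """Return True if `name` can be imported, without importing it."""
    try:
        # find_spec searches sys.path for the module; it raises
        # ModuleNotFoundError only if a *parent* package is missing.
        return importlib.util.find_spec(name) is not None
    except ModuleNotFoundError:
        return False


if is_package_available("some_optional_kernels"):  # hypothetical package name
    print("optional kernels can be imported")
else:
    print("optional kernels are not installed")
```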
......
@@ -8,6 +8,34 @@ from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
try:
    major, _minor = torch.cuda.get_device_capability()
except Exception:
    major = 1

HAS_EXLLAMA = False
CAN_EXLLAMA = major >= 8 or SYSTEM == "rocm"
V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"

if os.getenv("DISABLE_EXLLAMA") == "True":
    HAS_EXLLAMA = False
elif CAN_EXLLAMA:
    try:
        if V2:
            from text_generation_server.layers.gptq.exllamav2 import (
                QuantLinear as ExllamaQuantLinear,  # noqa: F401
            )

            HAS_EXLLAMA = "2"
        else:
            from text_generation_server.layers.gptq.exllama import (
                Ex4bitLinear as ExllamaQuantLinear,  # noqa: F401
            )

            HAS_EXLLAMA = "1"
    except ImportError:
        pass
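The imports above exist mainly so that other modules can re-import `ExllamaQuantLinear` from this package; without the `# noqa: F401` marker, ruff would report them as unused imports. A hedged sketch of that optional-backend re-export pattern, with an invented module and class name for illustration:

```python
# optional_backend.py -- hypothetical module, not part of this PR
HAS_FAST_KERNELS = False

try:
    # Imported only to re-export it from this module; `noqa: F401`
    # tells ruff the "unused import" is intentional.
    from fast_kernels import FastLinear as BackendLinear  # noqa: F401

    HAS_FAST_KERNELS = True
except ImportError:
    # Fall back silently; callers check HAS_FAST_KERNELS before using it.
    pass
```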
@dataclass
class GPTQWeight(Weight):
@@ -55,7 +83,7 @@ class GPTQWeight(Weight):
                from text_generation_server.layers.gptq import ExllamaQuantLinear
            except ImportError:
                raise NotImplementedError(
-                    f"Exllama gptq kernels are not installed. Install them `cd server/exllama_kernels && python setup.py install && cd ../exllamav2_kernels && python setup.py install`"
+                    "Exllama gptq kernels are not installed. Install them `cd server/exllama_kernels && python setup.py install && cd ../exllamav2_kernels && python setup.py install`"
                )
            return ExllamaQuantLinear(self, bias)
@@ -73,45 +101,6 @@ class GPTQWeight(Weight):
            )
try:
    major, _minor = torch.cuda.get_device_capability()
except Exception:
    major = 1

HAS_EXLLAMA = False
CAN_EXLLAMA = major >= 8 or SYSTEM == "rocm"
V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"
if os.getenv("DISABLE_EXLLAMA") == "True":
    HAS_EXLLAMA = False
elif CAN_EXLLAMA:
    try:
        if V2:
            from text_generation_server.layers.gptq.exllamav2 import (
                QuantLinear as ExllamaQuantLinear,
            )
            from text_generation_server.layers.gptq.exllamav2 import (
                create_exllama_buffers,
                set_device,
            )

            HAS_EXLLAMA = "2"
        else:
            from text_generation_server.layers.gptq.exllama import (
                Ex4bitLinear as ExllamaQuantLinear,
            )
            from text_generation_server.layers.gptq.exllama import (
                create_exllama_buffers,
                set_device,
            )

            HAS_EXLLAMA = "1"
    except ImportError:
        pass
from text_generation_server.layers.gptq.quant_linear import QuantLinear
class GPTQWeightsLoader(WeightsLoader):
    """
    Loader for GPTQ- and AWQ-quantized weights.
......
@@ -206,10 +206,13 @@ def matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):
    output = torch.empty(
        (input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16
    )
-    grid = lambda META: (
+    def grid(META):
+        return (
            triton.cdiv(input.shape[0], META["BLOCK_SIZE_M"])
            * triton.cdiv(qweight.shape[1], META["BLOCK_SIZE_N"]),
        )
    matmul_248_kernel[grid](
        input,
        qweight,
......
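Ruff's E731 rule discourages binding a lambda to a name, which is why the Triton launch `grid` above becomes a small nested `def`. A standalone sketch of the same refactor, with no Triton dependency and an arbitrary row count just to make it runnable:

```python
import math

# Before: a named lambda (flagged by ruff E731).
# grid = lambda meta: (math.ceil(1024 / meta["BLOCK_SIZE_M"]),)


# After: an equivalent named function; it can carry a docstring and
# shows up with a real name in tracebacks.
def grid(meta):
    """Number of program instances for a 1024-row launch (illustrative)."""
    return (math.ceil(1024 / meta["BLOCK_SIZE_M"]),)


print(grid({"BLOCK_SIZE_M": 64}))  # (16,)
```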
@@ -15,6 +15,7 @@ from text_generation_server.utils.hub import weight_files
from text_generation_server.layers.gptq.quant_linear import QuantLinear
from loguru import logger
from typing import Optional
from text_generation_server.layers.gptq.utils import torch_snr_error
from text_generation_server.utils.weights import DefaultWeightsLoader
@@ -372,7 +373,7 @@ def get_wikitext2(nsamples, seed, seqlen, model_id, trust_remote_code):
        tokenizer = AutoTokenizer.from_pretrained(
            model_id, use_fast=False, trust_remote_code=trust_remote_code
        )
-    except:
+    except Exception:
        tokenizer = AutoTokenizer.from_pretrained(
            model_id, use_fast=True, trust_remote_code=trust_remote_code
        )
@@ -404,7 +405,7 @@ def get_ptb(nsamples, seed, seqlen, model_id, trust_remote_code):
        tokenizer = AutoTokenizer.from_pretrained(
            model_id, use_fast=False, trust_remote_code=trust_remote_code
        )
-    except:
+    except Exception:
        tokenizer = AutoTokenizer.from_pretrained(
            model_id, use_fast=True, trust_remote_code=trust_remote_code
        )
@@ -448,7 +449,7 @@ def get_c4(nsamples, seed, seqlen, model_id, trust_remote_code):
        tokenizer = AutoTokenizer.from_pretrained(
            model_id, use_fast=False, trust_remote_code=trust_remote_code
        )
-    except:
+    except Exception:
        tokenizer = AutoTokenizer.from_pretrained(
            model_id, use_fast=True, trust_remote_code=trust_remote_code
        )
@@ -504,7 +505,7 @@ def get_ptb_new(nsamples, seed, seqlen, model_id, trust_remote_code):
        tokenizer = AutoTokenizer.from_pretrained(
            model_id, use_fast=False, trust_remote_code=trust_remote_code
        )
-    except:
+    except Exception:
        tokenizer = AutoTokenizer.from_pretrained(
            model_id, use_fast=True, trust_remote_code=trust_remote_code
        )
@@ -546,7 +547,7 @@ def get_c4_new(nsamples, seed, seqlen, model_id, trust_remote_code):
        tokenizer = AutoTokenizer.from_pretrained(
            model_id, use_fast=False, trust_remote_code=trust_remote_code
        )
-    except:
+    except Exception:
        tokenizer = AutoTokenizer.from_pretrained(
            model_id, use_fast=True, trust_remote_code=trust_remote_code
        )
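All of the `get_*` calibration loaders above share the same fallback: try the slow tokenizer first, then fall back to the fast one if that fails. Ruff's E722 fix only narrows the bare `except:` to `except Exception:`, so `KeyboardInterrupt` and `SystemExit` still propagate. A generic sketch of the pattern, not tied to any particular model id:

```python
from transformers import AutoTokenizer


def load_tokenizer(model_id: str, trust_remote_code: bool = False):
    try:
        # Prefer the slow (SentencePiece/Python) tokenizer for calibration.
        return AutoTokenizer.from_pretrained(
            model_id, use_fast=False, trust_remote_code=trust_remote_code
        )
    except Exception:
        # `except Exception` (instead of a bare `except:`) still lets
        # KeyboardInterrupt and SystemExit propagate.
        return AutoTokenizer.from_pretrained(
            model_id, use_fast=True, trust_remote_code=trust_remote_code
        )
```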
@@ -700,6 +701,8 @@ def sequential(
        pass

        def add_batch(name):
            nonlocal gptq

            def tmp(_, inp, out):
                gptq[name].add_batch(inp[0].data, out.data)
......
import torch


# copied from https://github.com/openppl-public/ppq/blob/master/ppq/quantization/measure/norm.py
def torch_snr_error(
    y_pred: torch.Tensor, y_real: torch.Tensor, reduction: str = "mean"
) -> torch.Tensor:
    """
    Compute SNR between y_pred (tensor) and y_real (tensor).

    SNR is calculated with the following equation:

        SNR(pred, real) = (pred - real) ^ 2 / (real) ^ 2

    If x and y are matrices, the SNR error over the matrix is the mean of the
    SNR error over all elements:

        SNR(pred, real) = mean((pred - real) ^ 2 / (real) ^ 2)

    Args:
        y_pred (torch.Tensor): predicted (e.g. quantized) tensor.
        y_real (torch.Tensor): reference tensor.
        reduction (str, optional): 'mean', 'sum' or 'none'. Defaults to 'mean'.

    Raises:
        ValueError: if the tensor shapes differ.
        ValueError: if the reduction method is unsupported.

    Returns:
        torch.Tensor: the SNR error, reduced according to `reduction`.
    """
    if y_pred.shape != y_real.shape:
        raise ValueError(
            f"Can not compute snr loss for tensors with different shape. "
            f"({y_pred.shape} and {y_real.shape})"
        )
    reduction = str(reduction).lower()

    if y_pred.ndim == 1:
        y_pred = y_pred.unsqueeze(0)
        y_real = y_real.unsqueeze(0)

    y_pred = y_pred.flatten(start_dim=1)
    y_real = y_real.flatten(start_dim=1)

    noise_power = torch.pow(y_pred - y_real, 2).sum(dim=-1)
    signal_power = torch.pow(y_real, 2).sum(dim=-1)
    snr = (noise_power) / (signal_power + 1e-7)

    if reduction == "mean":
        return torch.mean(snr)
    elif reduction == "sum":
        return torch.sum(snr)
    elif reduction == "none":
        return snr
    else:
        raise ValueError("Unsupported reduction method.")
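A short usage sketch of the helper defined above, comparing a tensor against a slightly perturbed copy; the shapes and noise level are illustrative only:

```python
import torch

# In the server this lives at:
# from text_generation_server.layers.gptq.utils import torch_snr_error

y_real = torch.randn(4, 128)
y_pred = y_real + 0.01 * torch.randn(4, 128)  # simulated quantization noise

# Mean SNR error over the batch; lower means the tensors are closer.
print(torch_snr_error(y_pred, y_real))          # scalar tensor
print(torch_snr_error(y_pred, y_real, "none"))  # one value per row, shape (4,)
```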
from typing import Optional
import torch
from text_generation_server.utils.import_utils import SYSTEM
from torch.nn import functional as F
......
-import math
-import os
-from typing import TYPE_CHECKING, Optional, Tuple, List
+from typing import TYPE_CHECKING, Optional, List
import torch
import torch.distributed
from accelerate import init_empty_weights
from torch import nn
from torch.nn import functional as F
from torch.distributed import ProcessGroup
from text_generation_server.utils.sgmv import (
......
from typing import List, Tuple
import torch
from text_generation_server.layers.marlin.fp8 import GPTQMarlinFP8Linear
from text_generation_server.layers.marlin.gptq import (
    GPTQMarlinLinear,
......
from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Union
import torch
import torch.nn as nn
@@ -85,7 +85,7 @@ class MarlinWeightsLoader(WeightsLoader):
                )
            except RuntimeError:
                raise RuntimeError(
-                    f"Cannot load `marlin` weight, make sure the model is already quantized"
+                    "Cannot load `marlin` weight, make sure the model is already quantized"
                )

            B_meta = torch.cat(
@@ -104,7 +104,7 @@ class MarlinWeightsLoader(WeightsLoader):
                )
            except RuntimeError:
                raise RuntimeError(
-                    f"Cannot load `marlin` weight, make sure the model is already quantized"
+                    "Cannot load `marlin` weight, make sure the model is already quantized"
                )

            s = torch.cat(
                [weights.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1
......
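The two `RuntimeError` messages above lose their `f` prefix because they contain no `{}` placeholders; ruff reports that as F541. A tiny sketch of the rule, with an invented prefix value just for the interpolated case:

```python
# Flagged by ruff F541: an f-string with nothing to interpolate.
# msg = f"Cannot load `marlin` weight, make sure the model is already quantized"

# Preferred: a plain string literal...
msg = "Cannot load `marlin` weight, make sure the model is already quantized"

# ...and keep the f prefix only when something is actually interpolated.
prefix = "model.layers.0.self_attn.q_proj"  # hypothetical tensor prefix
detail = f"Cannot load `marlin` weight for {prefix}"
```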
@@ -2,12 +2,9 @@ import os
import math
import torch
from torch import nn
from loguru import logger
from text_generation_server.utils.import_utils import SYSTEM

if SYSTEM == "cuda":
    from flash_attn.layers.rotary import RotaryEmbedding
    import rotary_emb
elif SYSTEM == "rocm":
    from vllm._C import ops
......
@@ -33,7 +33,7 @@ class SpeculativeHead(torch.nn.Module):
            except KeyError:
                try:
                    speculator = MedusaHeadV1.load(config, prefix, weights)
-                except:
+                except Exception:
                    speculator = MedusaHeadV2(config, prefix, weights)
                lm_head = None
        else:
......
@@ -2,7 +2,6 @@ import torch
from torch.nn import functional as F
from typing import Iterable, List
from text_generation_server.layers.linear import get_linear, FastLinear
from text_generation_server.layers.exl2 import Exl2Weight
from text_generation_server.utils.import_utils import SYSTEM

if SYSTEM == "ipex":
@@ -50,7 +49,7 @@ class TensorParallelHead(SuperLayer):
            # If the piece and LM head embeddings are shared, we have
            # non-quantized weights...
            weight = weights.get_tensor(f"{prefix}.weight")
-        except:
+        except Exception:
            # ...otherwise they are quantized.
            weight = weights.get_weights_col(prefix)
        should_gather = weights.process_group.size() > 1
@@ -67,15 +66,6 @@ class TensorParallelHead(SuperLayer):
            weight = weights.get_tensor(f"{prefix}.weight")
            should_gather = False
        # GPTQ,AWQ,EETQ don't quantize heads (nor embeddings)
        if config.quantize in ["gptq", "awq", "eetq", "marlin"]:
            quantize = None
        # See above, exl2 LM head can be quantized or not.
        elif config.quantize == "exl2" and not isinstance(weight, Exl2Weight):
            quantize = None
        else:
            quantize = config.quantize
        return TensorParallelHead(
            get_linear(weight, bias=None),
            process_group=weights.process_group,
......
# ruff: noqa: F821
# the above line disables the `undefined-name` rule for the model type variables
import torch
import enum
import os
@@ -712,6 +715,7 @@ def get_model(
        )
    elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3:
        print(f">>> model_type: {model_type}")
        if FLASH_ATTENTION:
            return FlashCausalLM(
                model_id=model_id,
@@ -856,7 +860,7 @@ def get_model(
                lora_adapter_ids=lora_adapter_ids,
                config_class=RWConfig,
            )
-        raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Falcon"))
+        raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Falcon"))
    else:
        if FLASH_ATTENTION and not config_dict.get("alibi", False):
            return FlashCausalLM(
......
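The `# ruff: noqa: F821` directive at the top of this file suppresses the undefined-name rule for the whole module, because the model type constants compared in `get_model` (`LLAMA`, `BAICHUAN`, `PHI3`, ...) are injected into the module namespace at runtime rather than written as literal assignments. A self-contained sketch of why that trips F821; the enum members and values here are invented for illustration, not the server's actual model list:

```python
# ruff: noqa: F821
# Model type names below are created dynamically, so a static checker
# cannot see them as assignments.
import enum


class ModelType(enum.Enum):
    LLAMA = "llama"
    BAICHUAN = "baichuan"


# Inject each member into the module globals under its own name.
for member in ModelType:
    globals()[member.name] = member

# To ruff, LLAMA looks like an undefined name (F821),
# even though it exists at runtime.
print(LLAMA is ModelType.LLAMA)  # True
```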
@@ -233,7 +233,7 @@ class CausalLMBatch(Batch):
        ]

        # Ensure that past_key_values tensors can be updated in-place
-        if type(self.past_key_values[0]) == tuple:
+        if type(self.past_key_values[0]) is tuple:
            self.past_key_values = [list(layer) for layer in self.past_key_values]

        # Update tensors in-place to allow incremental garbage collection
@@ -377,7 +377,7 @@ class CausalLMBatch(Batch):
        # BLOOM Keys: [batch_size * num_heads, head_dim, seq_length]
        # BLOOM Values: [batch_size * num_heads, seq_length, head_dim]
        # And ensure that we can update tensors in-place
-        if type(batch.past_key_values[0]) == tuple:
+        if isinstance(batch.past_key_values[0], tuple):
            batch.past_key_values = [
                [t.view(len(batch), -1, *t.shape[-2:]) for t in layer]
                for layer in batch.past_key_values
......
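Ruff's E721 rule drives both changes above: comparing `type(x) == tuple` is replaced either by an identity check (`is tuple`, exact type only) or by `isinstance`, which also accepts subclasses. A small sketch of the difference:

```python
from typing import NamedTuple


class Point(NamedTuple):  # NamedTuple instances are tuple subclasses
    x: int
    y: int


p = Point(1, 2)

print(type(p) is tuple)      # False: exact type check, Point is not tuple
print(isinstance(p, tuple))  # True: subclasses count

pkv = (1, 2)
print(type(pkv) is tuple)    # True for a plain tuple
```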
@@ -908,7 +908,7 @@ class BloomForCausalLM(BloomPreTrainedModel):
        loss = None

        if not return_dict:
-            output = (lm_logits,) + transformer_outputs[1:]
+            output = (logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return (
......
-from typing import Optional, Tuple, Union
+from typing import Optional, Tuple
import torch
from torch import nn
@@ -9,9 +9,7 @@ from transformers.modeling_attn_mask_utils import (
    _prepare_4d_attention_mask,
)
from transformers.modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPooling,
    ImageClassifierOutput,
)
from transformers import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
@@ -446,11 +444,12 @@ class CLIPEncoder(nn.Module):

class CLIPTextTransformer(nn.Module):
-    def __init__(self, prefix: str, config: CLIPTextConfig):
+    def __init__(self, prefix: str, config: CLIPTextConfig, weights=None):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size
        self.embeddings = CLIPTextEmbeddings(config)
        # Initialize weights and apply final processing with `self.post_init()`
        self.encoder = CLIPEncoder(
            prefix=f"{prefix}.encoder", config=config, weights=weights
        )
@@ -505,7 +504,7 @@ class CLIPTextTransformer(nn.Module):
            # text_embeds.shape = [batch_size, sequence_length, transformer.width]
            # take features from the eot embedding (eot_token is the highest number in each sequence)
            # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
-            pooled_output = last_hidden_state[
+            last_hidden_state[
                torch.arange(
                    last_hidden_state.shape[0], device=last_hidden_state.device
                ),
@@ -515,7 +514,7 @@ class CLIPTextTransformer(nn.Module):
            ]
        else:
            # The config gets updated `eos_token_id` from PR #24773 (so the use of exta new tokens is possible)
-            pooled_output = last_hidden_state[
+            last_hidden_state[
                torch.arange(
                    last_hidden_state.shape[0], device=last_hidden_state.device
                ),
@@ -565,9 +564,6 @@ class CLIPTextModel(CLIPPreTrainedModel):
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output  # pooled (EOS token) states
        ```"""
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        return self.text_model(
            input_ids=input_ids,
@@ -580,7 +576,6 @@ class CLIPVisionTransformer(nn.Module):
    def __init__(self, prefix, config: CLIPVisionConfig, weights):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size
        self.embeddings = CLIPVisionEmbeddings(
            prefix=f"{prefix}.embeddings", config=config, weights=weights
@@ -661,9 +656,6 @@ class CLIPVisionModel(CLIPPreTrainedModel):
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output  # pooled CLS states
        ```"""
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        return self.vision_model(
            pixel_values=pixel_values,
@@ -799,14 +791,12 @@ class CLIPModel(nn.Module):
        # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            return_dict=return_dict,
        )

        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            return_dict=return_dict,
        )

        image_embeds = vision_outputs[1]
......
@@ -30,7 +30,6 @@ from text_generation_server.layers.attention import (
    attention,
    reshape_and_cache,
)
from text_generation_server.models.globals import FLASH_DECODING
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers import (
    TensorParallelRowLinear,
......
@@ -44,7 +44,6 @@ from text_generation_server.layers.rotary import (
from text_generation_server.layers.layernorm import (
    FastLayerNorm,
)
from text_generation_server.utils.log import log_once


class DbrxAttentionConfig(PretrainedConfig):
......