Unverified commit bab02ff2, authored by drbh, committed by GitHub

feat: add ruff and resolve issue (#2262)

* feat: add ruff and resolve issue

* fix: update client exports and adjust after rebase

* fix: adjust syntax to avoid circular import

* fix: adjust client ruff settings

* fix: lint and refactor import check and avoid model enum as global names

* fix: improve fbgemm_gpu check and lints

* fix: update lints

* fix: prefer comparing model enum over str

* fix: adjust lints and ignore specific rules

* fix: avoid unneeded quantize check
parent 4b49c50f
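Two of the commit-message bullets above ("avoid model enum as global names" and "prefer comparing model enum over str") are easier to read outside the diff. The sketch below is illustrative only: `ModelType` and `get_model` are hypothetical stand-ins, not the server's actual API.

```python
from enum import Enum


class ModelType(Enum):
    # hypothetical enum used only for this illustration
    LLAMA = "llama"
    DEEPSEEK_V2 = "deepseek_v2"


def get_model(model_type: ModelType) -> str:
    # compare against the enum member, not its string value,
    # so a typo fails loudly instead of silently never matching
    if model_type == ModelType.DEEPSEEK_V2:
        return "loading DeepSeek-V2"
    return "loading default model"


print(get_model(ModelType.DEEPSEEK_V2))
```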
@@ -39,6 +39,12 @@ from torch import nn
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
+if SYSTEM == "rocm":
+    try:
+        from vllm import _custom_C
+    except Exception as e:
+        raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
 class DeepseekV2Config(PretrainedConfig):
     def __init__(
...
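The block added above guards a ROCm-only kernel import and re-raises with the underlying error message, in the same spirit as the "improve fbgemm_gpu check" bullet in the commit message. A minimal sketch of that guarded-optional-import pattern, using `importlib` and a placeholder module name:

```python
import importlib
import importlib.util

# Probe for an optional extension before importing it; the module name below is
# a placeholder for whatever kernel package a given platform provides.
_EXTENSION = "vllm"

if importlib.util.find_spec(_EXTENSION) is not None:
    try:
        extension = importlib.import_module(_EXTENSION)
    except Exception as e:
        # surface the real failure instead of a bare ImportError
        raise ImportError(f"Could not load `{_EXTENSION}`. Full error: {e}")
else:
    extension = None  # callers fall back to a non-accelerated code path
```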
@@ -46,7 +46,6 @@ from text_generation_server.layers.layernorm import (
     FastRMSNorm,
 )
 from text_generation_server.utils.weights import (
-    UnquantizedWeight,
     Weights,
 )
 from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
@@ -277,7 +276,7 @@ class LlamaMLP(nn.Module):
                 bias=bias,
             )
         else:
-            prefixes = [f"gate_proj", f"up_proj"]
+            prefixes = ["gate_proj", "up_proj"]
             sizes = [
                 config.intermediate_size,
                 config.intermediate_size,
...
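The `prefixes` change above is ruff's F541 fix (an f-string without any placeholders): the `f` prefix does nothing for a plain literal, so it is dropped. A small illustration, where `layer_prefix` is a hypothetical value used only for this example:

```python
# F541: an f-string with no placeholders is just a string literal.
prefixes = ["gate_proj", "up_proj"]  # plain literals, no "f" prefix needed

# f-strings stay where interpolation actually happens
layer_prefix = "model.layers.0"  # hypothetical prefix, for illustration only
qualified = [f"{layer_prefix}.{p}" for p in prefixes]
print(qualified)
```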
@@ -28,7 +28,6 @@ from typing import Optional, List, Tuple
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import (
-    Seqlen,
     paged_attention,
     attention,
     reshape_and_cache,
@@ -38,7 +37,6 @@ from text_generation_server.layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
     SpeculativeHead,
-    get_linear,
     TensorParallelMultiAdapterLinear,
     TensorParallelAdapterRowLinear,
 )
...
@@ -21,7 +21,6 @@
 import torch
 import torch.distributed
-import numpy as np
 from torch import nn
 from text_generation_server.utils.import_utils import SYSTEM
@@ -31,7 +30,6 @@ if SYSTEM != "ipex":
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
-from loguru import logger
 from text_generation_server.layers.attention import (
     paged_attention,
...
@@ -16,7 +16,6 @@
 import torch
 import torch.distributed
 from torch import nn
-from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
 from text_generation_server.layers.tensor_parallel import TensorParallelColumnLinear
...
@@ -15,7 +15,6 @@ from text_generation_server.layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
     SpeculativeHead,
-    get_linear,
 )
 from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
...
@@ -14,7 +14,7 @@
 # limitations under the License.
 """ PyTorch Idefics2 model."""
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Tuple
 import torch
 import torch.utils.checkpoint
@@ -22,10 +22,8 @@ from torch import nn
 import math
 from transformers.activations import ACT2FN
-from transformers.image_processing_utils import select_best_resolution
 from text_generation_server.models.custom_modeling.vlm import (
     load_text_model,
-    load_vision_model,
 )
 from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
...
@@ -19,6 +19,7 @@ import numpy as np
 from PIL import Image
+import transformers
 from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
 from transformers.image_transforms import (
     resize,
@@ -293,6 +294,4 @@ class IdeficsImageProcessor(BaseImageProcessor):
         return normalize(image, mean=mean, std=std, data_format=data_format, **kwargs)
-import transformers
 transformers.IdeficsImageProcessor = IdeficsImageProcessor
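The two hunks above move the module-level `import transformers` from the bottom of the file, where it sat next to the monkey-patch, up to the rest of the imports; ruff flags late module-level imports (E402). A trimmed sketch of the resulting layout, assuming `transformers` is installed (the class body here is a placeholder, not the real preprocessing code):

```python
import transformers
from transformers.image_processing_utils import BaseImageProcessor


class IdeficsImageProcessor(BaseImageProcessor):
    # trimmed for illustration; the real class implements the full
    # image preprocessing pipeline shown in the diff
    pass


# the monkey-patch at the bottom of the module reuses the top-level import
transformers.IdeficsImageProcessor = IdeficsImageProcessor
```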
@@ -21,10 +21,8 @@
 from typing import List, Optional, Tuple, Union
 import torch
-import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch import nn
-from torch.nn import CrossEntropyLoss
 from transformers import PreTrainedModel
 from transformers.activations import ACT2FN
@@ -33,13 +31,6 @@ from transformers.modeling_outputs import (
     CausalLMOutputWithPast,
     dataclass,
 )
-from transformers.modeling_utils import PretrainedConfig
-from transformers.utils import (
-    add_start_docstrings,
-    add_start_docstrings_to_model_forward,
-    logging,
-    replace_return_docstrings,
-)
 from text_generation_server.models.custom_modeling.idefics_config import IdeficsConfig
 from text_generation_server.models.custom_modeling.idefics_vision import (
     IdeficsVisionTransformer,
@@ -56,6 +47,7 @@ from text_generation_server.layers import (
 )
 from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.utils.import_utils import SYSTEM
+from loguru import logger
 if SYSTEM == "cuda":
     import dropout_layer_norm
@@ -237,7 +229,7 @@ class IdeficsDecoupledPartialTPEmbedding(nn.Module):
             prefix="model.embed_tokens", weights=weights
         )
         self.additional_weight = nn.Parameter(
-            weights.get_tensor(f"model.embed_tokens.additional_embedding.weight")
+            weights.get_tensor("model.embed_tokens.additional_embedding.weight")
         )
     def forward(self, input_ids):
@@ -499,7 +491,6 @@ class IdeficsAttention(nn.Module):
         # if not hasattr(nn.functional, "scaled_dot_product_attention"):
         # raise ValueError("this model requires pytorch 2.0 or higher")
-        process_group = weights.process_group
         if self.num_heads % weights.process_group.size() != 0:
             raise ValueError(
                 f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
@@ -1024,7 +1015,7 @@ class IdeficsModel(IdeficsPreTrainedModel):
         if config.use_resampler:
             perceiver_config = config.perceiver_config
             self.perceiver_resampler = IdeficsPerceiverResampler(
-                prefix=f"model.perceiver_resampler",
+                prefix="model.perceiver_resampler",
                 config=config,
                 embed_dim=config.vision_config.embed_dim,
                 depth=perceiver_config.resampler_depth,
@@ -1052,7 +1043,7 @@ class IdeficsModel(IdeficsPreTrainedModel):
         # self.gradient_checkpointing = False
         self.norm = IdeficsRMSNorm(
-            prefix=f"model.norm", weights=weights, eps=config.rms_norm_eps
+            prefix="model.norm", weights=weights, eps=config.rms_norm_eps
         )
         # self.gradient_checkpointing = False
...
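Several of the hunks above delete locals such as `process_group = weights.process_group` that were assigned but never read (ruff F841); the divisibility checks read `weights.process_group` directly instead. A self-contained sketch with a stand-in `Weights` class, not the server's real weights object:

```python
class Weights:
    """Stand-in for the real weights object, for illustration only."""

    class _ProcessGroup:
        def size(self) -> int:
            return 2  # pretend we shard across two ranks

    process_group = _ProcessGroup()


def check_heads(num_heads: int, weights: Weights) -> None:
    # read weights.process_group directly instead of keeping an unused alias
    if num_heads % weights.process_group.size() != 0:
        raise ValueError(
            f"`num_heads` must be divisible by `num_shards` "
            f"(got `num_heads`: {num_heads})"
        )


check_heads(32, Weights())
```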
@@ -169,7 +169,6 @@ class IdeficsPerceiverAttention(nn.Module):
         self.qk_scale = self.head_dim**-0.5
-        process_group = weights.process_group
         if n_heads % weights.process_group.size() != 0:
             raise ValueError(
                 f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {n_heads} "
...
@@ -28,9 +28,6 @@ from transformers.tokenization_utils_base import (
     TruncationStrategy,
 )
 from transformers.utils import TensorType, is_torch_available
-from text_generation_server.models.custom_modeling.idefics_image_processing import (
-    IdeficsImageProcessor,
-)
 if is_torch_available():
...
@@ -129,7 +129,6 @@ class IdeficsVisionAttention(nn.Module):
         self.scale = self.head_dim**-0.5
         self.dropout = config.attention_dropout
-        process_group = weights.process_group
         if self.num_heads % weights.process_group.size() != 0:
             raise ValueError(
                 f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
@@ -460,7 +459,6 @@ class IdeficsVisionTransformer(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
         self.config = config
-        embed_dim = config.hidden_size
         self.embeddings = IdeficsVisionEmbeddings(
             prefix=f"{prefix}.embeddings", config=config, weights=weights
...
@@ -14,7 +14,7 @@
 # limitations under the License.
 """ PyTorch Llava-NeXT model."""
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Tuple
 import torch
 import torch.utils.checkpoint
...
...@@ -4,7 +4,6 @@ Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py ...@@ -4,7 +4,6 @@ Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
""" """
import math import math
import os
import warnings import warnings
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
import torch import torch
...@@ -194,7 +193,7 @@ def flash_attn_fn( ...@@ -194,7 +193,7 @@ def flash_attn_fn(
): ):
try: try:
from flash_attn import bert_padding, flash_attn_interface from flash_attn import bert_padding, flash_attn_interface
except: except Exception:
raise RuntimeError("Please install flash-attn==1.0.3.post0") raise RuntimeError("Please install flash-attn==1.0.3.post0")
check_valid_inputs(query, key, value) check_valid_inputs(query, key, value)
if past_key_value is not None: if past_key_value is not None:
...@@ -207,7 +206,7 @@ def flash_attn_fn( ...@@ -207,7 +206,7 @@ def flash_attn_fn(
_s_k = max(0, attn_bias.size(3) - key.size(1)) _s_k = max(0, attn_bias.size(3) - key.size(1))
attn_bias = attn_bias[:, :, _s_q:, _s_k:] attn_bias = attn_bias[:, :, _s_q:, _s_k:]
if attn_bias is not None: if attn_bias is not None:
raise NotImplementedError(f"attn_bias not implemented for flash attn.") raise NotImplementedError("attn_bias not implemented for flash attn.")
(batch_size, seqlen) = query.shape[:2] (batch_size, seqlen) = query.shape[:2]
if key_padding_mask is None: if key_padding_mask is None:
key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool) key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool)
...@@ -269,13 +268,13 @@ def triton_flash_attn_fn( ...@@ -269,13 +268,13 @@ def triton_flash_attn_fn(
): ):
try: try:
from .flash_attn_triton import flash_attn_func from .flash_attn_triton import flash_attn_func
except: except Exception:
_installed = False _installed = False
if version.parse(torch.__version__) < version.parse("2.0.0"): if version.parse(torch.__version__) < version.parse("2.0.0"):
_installed = True _installed = True
try: try:
from flash_attn.flash_attn_triton import flash_attn_func from flash_attn.flash_attn_triton import flash_attn_func
except: except Exception:
_installed = False _installed = False
if not _installed: if not _installed:
raise RuntimeError( raise RuntimeError(
...@@ -292,9 +291,9 @@ def triton_flash_attn_fn( ...@@ -292,9 +291,9 @@ def triton_flash_attn_fn(
_s_k = max(0, attn_bias.size(3) - key.size(1)) _s_k = max(0, attn_bias.size(3) - key.size(1))
attn_bias = attn_bias[:, :, _s_q:, _s_k:] attn_bias = attn_bias[:, :, _s_q:, _s_k:]
if dropout_p: if dropout_p:
raise NotImplementedError(f"Dropout not implemented for attn_impl: triton.") raise NotImplementedError("Dropout not implemented for attn_impl: triton.")
if needs_weights: if needs_weights:
raise NotImplementedError(f"attn_impl: triton cannot return attn weights.") raise NotImplementedError("attn_impl: triton cannot return attn weights.")
if key_padding_mask is not None: if key_padding_mask is not None:
warnings.warn( warnings.warn(
"Propagating key_padding_mask to the attention module " "Propagating key_padding_mask to the attention module "
...@@ -428,7 +427,7 @@ class MultiQueryAttention(nn.Module): ...@@ -428,7 +427,7 @@ class MultiQueryAttention(nn.Module):
additive bias. additive bias.
""" """
def __init__(self, config, prefix, weights): def __init__(self, config, prefix, weights, verbose=False):
super().__init__() super().__init__()
attn_impl = config.attn_config.attn_impl attn_impl = config.attn_config.attn_impl
self.attn_impl = config.attn_config.attn_impl self.attn_impl = config.attn_config.attn_impl
...@@ -445,7 +444,7 @@ class MultiQueryAttention(nn.Module): ...@@ -445,7 +444,7 @@ class MultiQueryAttention(nn.Module):
self.Wqkv = TensorParallelColumnLinear.load( self.Wqkv = TensorParallelColumnLinear.load(
config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias
) )
fuse_splits = (d_model, d_model + self.head_dim) (d_model, d_model + self.head_dim)
if self.qk_ln: if self.qk_ln:
raise NotImplementedError("qk_ln not supported") raise NotImplementedError("qk_ln not supported")
if self.attn_impl == "flash": if self.attn_impl == "flash":
...@@ -795,7 +794,9 @@ class MPTModel(MPTPreTrainedModel): ...@@ -795,7 +794,9 @@ class MPTModel(MPTPreTrainedModel):
self.alibi = config.attn_config.alibi self.alibi = config.attn_config.alibi
self.alibi_bias_max = config.attn_config.alibi_bias_max self.alibi_bias_max = config.attn_config.alibi_bias_max
if config.init_device == "mixed": if config.init_device == "mixed":
if dist.get_local_rank() == 0: # TODO: reimplement mixed device initialization
# dist.get_local_rank() == 0:
if True:
config.init_device = "cpu" config.init_device = "cpu"
else: else:
config.init_device = "meta" config.init_device = "meta"
...@@ -1016,7 +1017,7 @@ class MPTModel(MPTPreTrainedModel): ...@@ -1016,7 +1017,7 @@ class MPTModel(MPTPreTrainedModel):
if past_key_values is not None: if past_key_values is not None:
if len(past_key_values) != self.config.n_layers: if len(past_key_values) != self.config.n_layers:
raise ValueError( raise ValueError(
f"past_key_values must provide a past_key_value for each attention " "past_key_values must provide a past_key_value for each attention "
+ f"layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.config.n_layers!r})." + f"layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.config.n_layers!r})."
) )
past_position = past_key_values[0][0].size(1) past_position = past_key_values[0][0].size(1)
...@@ -1182,7 +1183,7 @@ class MPTForCausalLM(MPTPreTrainedModel): ...@@ -1182,7 +1183,7 @@ class MPTForCausalLM(MPTPreTrainedModel):
input_ids = input_ids[:, -1].unsqueeze(-1) input_ids = input_ids[:, -1].unsqueeze(-1)
if self.transformer.prefix_lm: if self.transformer.prefix_lm:
prefix_mask = torch.ones_like(attention_mask) prefix_mask = torch.ones_like(attention_mask)
if kwargs.get("use_cache") == False: if kwargs.get("use_cache") is False:
raise NotImplementedError( raise NotImplementedError(
"MPT with prefix_lm=True does not support use_cache=False." "MPT with prefix_lm=True does not support use_cache=False."
) )
......
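The MPT changes above replace bare `except:` clauses with `except Exception:` (ruff E722) and `== False` comparisons with `is False` (E712). A short sketch of both patterns; the optional import name is used only for illustration and may not be installed:

```python
def load_optional_backend():
    try:
        import flash_attn  # optional dependency; may not be installed
    except Exception:  # E722: a bare `except:` would also swallow SystemExit
        return None
    return flash_attn


use_cache = None  # tri-state flag: None means "not specified"
if use_cache is False:  # E712: identity check distinguishes False from None or 0
    raise NotImplementedError("use_cache=False is not supported here")
print(load_optional_backend() is not None)
```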
@@ -21,25 +21,14 @@ import torch
 import torch.distributed
 import torch.utils.checkpoint
 from torch import nn
-from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from torch.nn import CrossEntropyLoss
 from transformers.activations import ACT2FN
-from transformers.file_utils import (
-    add_code_sample_docstrings,
-    add_start_docstrings,
-    add_start_docstrings_to_model_forward,
-    replace_return_docstrings,
-)
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
-    QuestionAnsweringModelOutput,
-    SequenceClassifierOutputWithPast,
-    TokenClassifierOutput,
 )
 from transformers.modeling_utils import PreTrainedModel
-from transformers import GPTNeoXConfig
-from loguru import logger
 from text_generation_server.layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
@@ -133,7 +122,6 @@ class GPTNeoXAttention(nn.Module):
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.num_attention_heads
         self.rotary_ndims = int(self.head_size * config.rotary_pct)
-        max_positions = config.max_position_embeddings
         # ??? TODO
         # self.register_buffer(
         #     "bias",
...
@@ -5,7 +5,7 @@ import torch.distributed
 import math
 from torch import nn
-from typing import Optional, List, Tuple, Any
+from typing import Optional, List, Tuple
 from transformers.configuration_utils import PretrainedConfig
 from transformers.modeling_outputs import CausalLMOutputWithPast
...
-from typing import Optional, Tuple, Union
+from typing import Optional, Tuple
+import warnings
 import math
 import torch
 from torch import nn
 from transformers.activations import ACT2FN
-from transformers.modeling_attn_mask_utils import (
-    _create_4d_causal_attention_mask,
-    _prepare_4d_attention_mask,
-)
 from transformers.modeling_outputs import (
-    BaseModelOutput,
     BaseModelOutputWithPooling,
-    ImageClassifierOutput,
 )
-from transformers import SiglipConfig, SiglipTextConfig, SiglipVisionConfig
+from transformers import SiglipConfig, SiglipVisionConfig
+from torch.nn.init import _calculate_fan_in_and_fan_out
 from text_generation_server.layers.tensor_parallel import (
     TensorParallelEmbedding,
@@ -244,9 +239,6 @@ class SiglipMultiheadAttentionPoolingHead(nn.Module):
         return hidden_state[:, 0]
-import warnings
 def _trunc_normal_(tensor, mean, std, a, b):
     # Cut & paste from PyTorch official master until it's in a few official releases - RW
     # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
@@ -264,12 +256,12 @@ def _trunc_normal_(tensor, mean, std, a, b):
     # Values are generated by using a truncated uniform distribution and
     # then using the inverse CDF for the normal distribution.
     # Get upper and lower cdf values
-    l = norm_cdf((a - mean) / std)
-    u = norm_cdf((b - mean) / std)
+    lower = norm_cdf((a - mean) / std)
+    upper = norm_cdf((b - mean) / std)
     # Uniformly fill tensor with values from [l, u], then translate to
     # [2l-1, 2u-1].
-    tensor.uniform_(2 * l - 1, 2 * u - 1)
+    tensor.uniform_(2 * lower - 1, 2 * upper - 1)
     # Use inverse cdf transform for normal distribution to get truncated
     # standard normal
@@ -313,9 +305,6 @@ def trunc_normal_tf_(
     tensor.mul_(std).add_(mean)
-from torch.nn.init import _calculate_fan_in_and_fan_out
 def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
     fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
     if mode == "fan_in":
@@ -349,9 +338,6 @@ def default_flax_embed_init(tensor):
     variance_scaling_(tensor, mode="fan_in", distribution="normal")
-from transformers import PreTrainedModel
 class SiglipEncoder(nn.Module):
     """
     Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
@@ -393,7 +379,6 @@ class SiglipVisionTransformer(nn.Module):
     def __init__(self, prefix, config: SiglipVisionConfig, weights):
         super().__init__()
         self.config = config
-        embed_dim = config.hidden_size
         self.embeddings = SiglipVisionEmbeddings(
             prefix=f"{prefix}.embeddings", config=config, weights=weights
...
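In `_trunc_normal_` above, the single-letter names `l` and `u` become `lower` and `upper`; ruff's E741 rule flags ambiguous names such as `l`, which is easily confused with `1` or `I`. A standalone version of that CDF computation, as a sketch with hard-coded bounds:

```python
import math


def norm_cdf(x: float) -> float:
    # standard normal CDF, as used inside _trunc_normal_
    return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0


a, b, mean, std = -2.0, 2.0, 0.0, 1.0
lower = norm_cdf((a - mean) / std)  # was `l`, flagged as ambiguous (E741)
upper = norm_cdf((b - mean) / std)  # was `u`, renamed for symmetry
print(lower, upper)
```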
@@ -45,6 +45,15 @@ from text_generation_server.layers import (
     SpeculativeHead,
 )
+# copied from https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/models/t5/modeling_t5.py#L1316
+# Warning message for FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
+__HEAD_MASK_WARNING_MSG = """
+The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently,
+`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions.
+If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers,
+num_heads)`.
+"""
 class PartialTPEmbedding(nn.Module):
     def __init__(self, prefix: str, weights):
@@ -1132,12 +1141,12 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
         if labels is not None:
             loss_fct = CrossEntropyLoss(ignore_index=-100)
             # move labels to correct device to enable PP
-            labels = labels.to(lm_logits.device)
-            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
+            labels = labels.to(logits.device)
+            loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
             # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
         if not return_dict:
-            output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
+            output = (logits,) + decoder_outputs[1:] + encoder_outputs
             return ((loss,) + output) if loss is not None else output
         return (
...
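The T5 hunk renames `lm_logits` to `logits` in the loss and output paths; ruff's undefined-name check (F821) catches references to variables that are never bound in the enclosing scope, and the diff consistently uses the bound name `logits`. A tiny sketch of the same pattern; `compute_loss` and its arguments are placeholders, not the T5 API:

```python
import torch
from torch.nn import CrossEntropyLoss


def compute_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    loss_fct = CrossEntropyLoss(ignore_index=-100)
    # referring to a name like `lm_logits` here would be flagged by F821
    # (undefined name); only `logits` is bound in this scope
    labels = labels.to(logits.device)
    return loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))


logits = torch.randn(2, 3, 10)  # (batch, sequence, vocab)
labels = torch.randint(0, 10, (2, 3))
print(compute_loss(logits, labels).item())
```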
@@ -42,7 +42,7 @@ def load_vision_model(prefix, config, weights):
         )
         return SiglipVisionTransformer(
-            prefix=f"vision_tower.vision_model", config=config, weights=weights
+            prefix="vision_tower.vision_model", config=config, weights=weights
         )
     else:
         raise RuntimeError(f"Unsupported model type {config.model_type}")
@@ -1194,7 +1194,7 @@ class FlashCausalLM(Model):
                     if self.speculate is None or self.speculate + 1 <= bs:
                         self.cuda_graph_warmup(bs, max_s, max_bt)
                 except torch.cuda.OutOfMemoryError:
-                    logger.exception(f"Decode cuda graph warmup failed")
+                    logger.exception("Decode cuda graph warmup failed")
         else:
             log_master(
                 logger.info, f"Cuda Graphs are disabled (CUDA_GRAPHS={CUDA_GRAPHS})."
...