Unverified Commit 3f20877d authored by tomeras91, committed by GitHub

Add jamba (#29943)

* Add jamba arch

* apply "make fix-copies" changes

* fix link to model in JambaConfig docstring

* Add n_ctx in modeling file because repo-consistency wants that

* Add jamba to flash attention and sdpa documentation

* mamba dt_proj quant fix now works for LoRA as well

* override test_left_padding_compatibility and use a more permissive tolerance. left padding numerical differences are accentuated by mamba layers

* add jamba to tokenization auto

* fix comments of shape (PR #24 in the model page: https://huggingface.co/ai21labs/Jamba-v0.1/discussions/24)

* simple PR fixes

* remove unnecessary kwargs from JambaAttentionDecoderLayer and JambaMambaDecoderLayer

* remove the LoRA hack for the mamba dt_proj bias. It was solved in huggingface/peft#1530 (https://github.com/huggingface/peft/pull/1530)

* Add copied comment on JambaMLP (it's the same as MixtralMLP)

* remove padding_mask warnings. It's not supported anymore

* fix docstring. Float instead of int

* A few more minor PR fixes

* (1) lowercase names for mamba layernorms (2) remove _apply_inner_layernorms and do it directly in the forward pass

* Return None attention weights from mamba layers. Append to all attentions only if not None.

* remove some leftover jamba archive lists

* Better separation between expert vs non-expert layers. non-expert layers return None as router_logits, and it is not concatenated to all_router_logits returned from JambaModel

* no need to take router_logits at config.expert_layer_offset anymore. result.router_logits now holds results only for expert layers

* Add Jamba paper on READMEs

* (1) rename n_ctx -> max_position_embeddings (2) don't use it in the modeling file since it's not needed (set it as an exception to check_config_attributes)

* Add copied from comment

* remove the code path for apply_inner_layernorms=False. Jamba always has the inner mamba layernorms

* clearer docstring for _convert_to_standard_cache

* style fixes

* Change calc_logits_for_entire_prompt (bool) to num_logits_to_keep (int). Adapt assisted decoding code to use it. Also a small change in the low memory beam search decoding path to support this new int value in model_inputs (an illustrative sketch follows this change list)

* rename test so it still overrides what it's meant to override

* draft

* oups

* nit

* remove more complex logic

* fix names used in config

* fix fix fix

* style

* fix some more failing tests

* generate did not init the cache 🙃



* more small nits

* typo

* config.mamba_expand * config.hidden_size for the intermediate size of the mamba shapes

* fix init of pkv with torch.tensor()

* empty tensor

* fix some init issues

* stupid changes required by generate because it does not even support its own DynamicCache class

* more fixes

* fix general assisted gen cache_position bug

* tests passing

* Add offsets and periods as SPECIAL_CASES_TO_ALLOW in check_config_attributes.py

* fix reorder_cache to reorder mamba states and override some more functions in HybridMambaAttentionDynamicCache

* no need to override test_past_key_values_format() and _check_past_key_values_for_generate() in tests anymore

* fix docstrings and typehints for past_key_values

* style fixes

* fix docs

* change typehint due to copy from Mixtral

* forgot import

* import order

* Add configuration_jamba and modeling_jamba to not_doctested because the model is too big to download (in docstring of JambaForCausalLM.forward)

* Add integration test with tiny random Jamba model on hub

* fix flash attention cache shapes

* bring back forgotten hidden states

* rename HybridMambaAttentionDynamicCache.seqlen_offset to has_previous_state (and make bool) and bugfix - it should be set to True after a finished forward pass of the entire model

* align integration test after modeling fixes

* bugfix - mamba can use precomputed states only if the forward pass is on a single token

* bugfix - mamba can use precomputed states only if they match the batch size

* typo

* remove making _prepare_4d_causal_attention_mask a leaf function

* stop using past_seq_len.get_seq_length(). Use cache positions instead. Adjust test (test_decoder_model_past_with_large_inputs) accordingly
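
The `num_logits_to_keep` change described in the bullet above ("Change calc_logits_for_entire_prompt (bool) to num_logits_to_keep (int)") can be illustrated in isolation. The sketch below is only an assumption-laden toy, not the Jamba modeling code (that diff is collapsed further down); `lm_head_logits` and its tensor names are invented for illustration:

```python
import torch

def lm_head_logits(hidden_states: torch.Tensor, lm_head: torch.nn.Linear, num_logits_to_keep: int = 1) -> torch.Tensor:
    """Project only the last `num_logits_to_keep` positions to vocabulary logits.

    `hidden_states` has shape [batch, seq_len, hidden_size]. With `num_logits_to_keep=None`
    (or 0) all positions are projected; with 1, only the last prompt token's logits are kept,
    which is all that generation needs and saves memory on long prompts.
    """
    if not num_logits_to_keep:  # None or 0 -> keep logits for every position
        return lm_head(hidden_states)                           # [batch, seq_len, vocab]
    return lm_head(hidden_states[:, -num_logits_to_keep:, :])   # [batch, num_logits_to_keep, vocab]

head = torch.nn.Linear(8, 32, bias=False)
hidden = torch.randn(2, 10, 8)
print(lm_head_logits(hidden, head).shape)     # torch.Size([2, 1, 32])
print(lm_head_logits(hidden, head, 0).shape)  # torch.Size([2, 10, 32])
```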

---------
Co-authored-by: Arthur Zucker <arthur.zucker@gmail.com>
Co-authored-by: Joao Gante <joao@huggingface.co>
parent 28a22834
......@@ -18,6 +18,8 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
import torch
from ..cache_utils import DynamicCache
if TYPE_CHECKING:
from ..modeling_utils import PreTrainedModel
......@@ -371,7 +373,13 @@ def _crop_past_key_values(model, past_key_values, maximum_length):
else:
for idx in range(len(past_key_values)):
past_key_values[idx] = past_key_values[idx][:, :, :maximum_length, :]
else:
elif isinstance(past_key_values, DynamicCache):
for idx in range(len(past_key_values.key_cache)):
if past_key_values.value_cache[idx].shape[-1] != 0:
past_key_values.key_cache[idx] = past_key_values.key_cache[idx][:, :, :maximum_length, :]
past_key_values.value_cache[idx] = past_key_values.value_cache[idx][:, :, :maximum_length, :]
elif past_key_values is not None:
for idx in range(len(past_key_values)):
new_past.append(
(
......
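
For context on the `DynamicCache` branch added to `_crop_past_key_values` above, here is a minimal runnable sketch of the same cropping logic. `TinyDynamicCache` is a simplified stand-in, not the `transformers` class:

```python
import torch

class TinyDynamicCache:
    """Simplified stand-in for transformers' DynamicCache: one key/value tensor per layer,
    each shaped [batch, num_heads, seq_len, head_dim]. Not the real class."""

    def __init__(self, num_layers: int, batch: int, heads: int, seq_len: int, head_dim: int):
        self.key_cache = [torch.randn(batch, heads, seq_len, head_dim) for _ in range(num_layers)]
        self.value_cache = [torch.randn(batch, heads, seq_len, head_dim) for _ in range(num_layers)]

def crop_cache(cache: TinyDynamicCache, maximum_length: int) -> None:
    """Drop cached keys/values beyond `maximum_length`, mirroring the new elif branch above."""
    for idx in range(len(cache.key_cache)):
        # skip layers whose value cache is empty (last dimension of size 0)
        if cache.value_cache[idx].shape[-1] != 0:
            cache.key_cache[idx] = cache.key_cache[idx][:, :, :maximum_length, :]
            cache.value_cache[idx] = cache.value_cache[idx][:, :, :maximum_length, :]

cache = TinyDynamicCache(num_layers=2, batch=1, heads=4, seq_len=10, head_dim=8)
crop_cache(cache, maximum_length=6)
print(cache.key_cache[0].shape)  # torch.Size([1, 4, 6, 8])
```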
......@@ -598,7 +598,11 @@ class GenerationMixin:
def _expand_dict_for_generation(dict_to_expand):
for key in dict_to_expand:
if dict_to_expand[key] is not None and isinstance(dict_to_expand[key], torch.Tensor):
if (
key != "cache_position"
and dict_to_expand[key] is not None
and isinstance(dict_to_expand[key], torch.Tensor)
):
dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
return dict_to_expand
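
The hunk above excludes `cache_position` from batch expansion. A hedged, standalone sketch of the same pattern (the function and dict names are illustrative, not the library API):

```python
import torch

def expand_for_generation(model_kwargs: dict, expand_size: int) -> dict:
    """Repeat every batched tensor `expand_size` times along dim 0 (e.g. for beam expansion),
    but leave `cache_position` untouched: it indexes positions, not samples."""
    for key, value in model_kwargs.items():
        if key != "cache_position" and isinstance(value, torch.Tensor):
            model_kwargs[key] = value.repeat_interleave(expand_size, dim=0)
    return model_kwargs

kwargs = {
    "attention_mask": torch.ones(2, 5, dtype=torch.long),
    "cache_position": torch.arange(5),
}
expanded = expand_for_generation(kwargs, expand_size=3)
print(expanded["attention_mask"].shape)   # torch.Size([6, 5])
print(expanded["cache_position"].shape)   # torch.Size([5]) -- unchanged
```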
......@@ -2094,7 +2098,8 @@ class GenerationMixin:
# Replicates the new past_key_values to match the `top_k` candidates
new_key_values = []
for layer in model_kwargs["past_key_values"]:
past = model_kwargs["past_key_values"]
for layer in past:
items = []
# item is either the key or the value matrix
for item in layer:
......@@ -2103,7 +2108,13 @@ class GenerationMixin:
else:
items.append(item.repeat_interleave(top_k, dim=0))
new_key_values.append(tuple(items))
model_kwargs["past_key_values"] = tuple(new_key_values)
if not isinstance(past, DynamicCache):
past = tuple(new_key_values)
else:
for layer_idx in range(len(new_key_values)):
past.key_cache[layer_idx] = new_key_values[layer_idx][0]
past.value_cache[layer_idx] = new_key_values[layer_idx][1]
model_kwargs["past_key_values"] = past
if sequential:
all_outputs = []
......@@ -2178,16 +2189,22 @@ class GenerationMixin:
else:
next_past_key_values = self._extract_past_from_model_output(outputs, standardize_cache_format=True)
new_key_values = ()
new_key_values = []
for layer in next_past_key_values:
items = ()
items = []
# item is either the key or the value matrix
for item in layer:
item = torch.stack(torch.split(item, top_k, dim=0)) # [B, K, num_head, seq_len, esz]
item = item[range(batch_size), selected_idx, ...] # [B, num_head, seq_len, esz]
items += (item,)
new_key_values += (items,)
next_past_key_values = new_key_values
items += [item]
new_key_values += [items]
if not isinstance(next_past_key_values, DynamicCache):
next_past_key_values = tuple(new_key_values)
else:
for layer_idx in range(len(new_key_values)):
next_past_key_values.key_cache[layer_idx] = new_key_values[layer_idx][0]
next_past_key_values.value_cache[layer_idx] = new_key_values[layer_idx][1]
logit_for_next_step = torch.stack(torch.split(logits, top_k))[range(batch_size), selected_idx, :]
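
The reshaping in the hunk above (split the flat `[batch * top_k, ...]` cache into `top_k` candidates, then keep one per batch element) can be checked in isolation with toy shapes; this is only an illustrative sketch:

```python
import torch

batch_size, top_k, num_heads, seq_len, head_dim = 2, 3, 4, 5, 8

# a cached key/value tensor laid out as [batch * top_k, num_heads, seq_len, head_dim]
item = torch.randn(batch_size * top_k, num_heads, seq_len, head_dim)
# index of the candidate each batch element selected (one of the top_k continuations)
selected_idx = torch.tensor([1, 2])

# [B*K, ...] -> [B, K, ...]: consecutive chunks of size top_k belong to the same batch element
stacked = torch.stack(torch.split(item, top_k, dim=0))   # [B, K, num_heads, seq_len, head_dim]
picked = stacked[range(batch_size), selected_idx, ...]   # [B, num_heads, seq_len, head_dim]
print(picked.shape)  # torch.Size([2, 4, 5, 8])
```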
......@@ -3127,6 +3144,7 @@ class GenerationMixin:
"transo_xl",
"xlnet",
"cpm",
"jamba",
]
):
raise RuntimeError(
......@@ -4645,21 +4663,22 @@ class GenerationMixin:
# we use this forward pass to also pick the subsequent logits in the original model.
# 2.1. Prepare the model inputs
candidate_kwargs = copy.copy(model_kwargs)
candidate_kwargs = _prepare_attention_mask(
candidate_kwargs, candidate_input_ids.shape[1], self.config.is_encoder_decoder
model_kwargs = _prepare_attention_mask(
model_kwargs, candidate_input_ids.shape[1], self.config.is_encoder_decoder
)
candidate_kwargs = _prepare_token_type_ids(candidate_kwargs, candidate_input_ids.shape[1])
if "cache_position" in candidate_kwargs:
candidate_kwargs["cache_position"] = torch.cat(
model_kwargs = _prepare_token_type_ids(model_kwargs, candidate_input_ids.shape[1])
if "cache_position" in model_kwargs:
model_kwargs["cache_position"] = torch.cat(
(
candidate_kwargs["cache_position"],
model_kwargs["cache_position"],
torch.arange(cur_len, cur_len + candidate_length, device=input_ids.device, dtype=torch.long),
),
dim=0,
)
model_inputs = self.prepare_inputs_for_generation(candidate_input_ids, **candidate_kwargs)
model_inputs = self.prepare_inputs_for_generation(candidate_input_ids, **model_kwargs)
if "num_logits_to_keep" in model_inputs:
model_inputs["num_logits_to_keep"] = candidate_length + 1
# 2.2. Run a forward pass on the candidate sequence
outputs = self(
......@@ -4985,7 +5004,7 @@ def _split_model_inputs(
# ModelOutput object.
# bool should not be split but replicated for each split
bool_keys = [k for k in keys if isinstance(model_input[k], bool) or k == "cache_position"]
keys_to_ignore = ["cache_position", "encoder_outputs"]
keys_to_ignore = ["cache_position", "encoder_outputs", "num_logits_to_keep"]
non_bool_keys = [k for k in keys if not isinstance(model_input[k], bool) and k not in keys_to_ignore]
# we split the tensors and tuples of tensors
......@@ -5001,6 +5020,11 @@ def _split_model_inputs(
data_split_list = [
{**data_split, "encoder_outputs": encoder_outputs_split[i]} for i, data_split in enumerate(data_split_list)
]
# num_logits_to_keep should be replicated for each split, similar to bool values
if "num_logits_to_keep" in model_input:
data_split_list = [
{**data_split, "num_logits_to_keep": model_input["num_logits_to_keep"]} for data_split in data_split_list
]
# Convert each dictionary in the list to an object of the inferred class
split_model_inputs: List[Union[ModelOutput, Dict]] = [
......
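
A hedged sketch of the splitting behaviour touched above for sequential (low-memory) decoding: batched tensors are split along dim 0, while scalar entries such as `num_logits_to_keep` are replicated into every split. `split_model_inputs` below is a simplified stand-in, not the actual `_split_model_inputs`:

```python
import torch

def split_model_inputs(model_input: dict, split_size: int, full_batch_size: int) -> list:
    """Split batched tensors into chunks of `split_size` along dim 0; replicate scalar
    entries (e.g. `num_logits_to_keep`, booleans) unchanged into every chunk."""
    splits = []
    for start in range(0, full_batch_size, split_size):
        chunk = {}
        for key, value in model_input.items():
            if isinstance(value, torch.Tensor):
                chunk[key] = value[start : start + split_size]
            else:
                chunk[key] = value  # same value for every split
        splits.append(chunk)
    return splits

inputs = {"input_ids": torch.arange(12).reshape(4, 3), "num_logits_to_keep": 1}
for split in split_model_inputs(inputs, split_size=2, full_batch_size=4):
    print(split["input_ids"].shape, split["num_logits_to_keep"])
# torch.Size([2, 3]) 1
# torch.Size([2, 3]) 1
```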
......@@ -115,6 +115,7 @@ from . import (
imagegpt,
informer,
instructblip,
jamba,
jukebox,
kosmos2,
layoutlm,
......
......@@ -129,6 +129,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("imagegpt", "ImageGPTConfig"),
("informer", "InformerConfig"),
("instructblip", "InstructBlipConfig"),
("jamba", "JambaConfig"),
("jukebox", "JukeboxConfig"),
("kosmos-2", "Kosmos2Config"),
("layoutlm", "LayoutLMConfig"),
......@@ -397,6 +398,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
("imagegpt", "ImageGPT"),
("informer", "Informer"),
("instructblip", "InstructBLIP"),
("jamba", "Jamba"),
("jukebox", "Jukebox"),
("kosmos-2", "KOSMOS-2"),
("layoutlm", "LayoutLM"),
......
......@@ -123,6 +123,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("idefics2", "Idefics2Model"),
("imagegpt", "ImageGPTModel"),
("informer", "InformerModel"),
("jamba", "JambaModel"),
("jukebox", "JukeboxModel"),
("kosmos-2", "Kosmos2Model"),
("layoutlm", "LayoutLMModel"),
......@@ -451,6 +452,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
("gpt_neox", "GPTNeoXForCausalLM"),
("gpt_neox_japanese", "GPTNeoXJapaneseForCausalLM"),
("gptj", "GPTJForCausalLM"),
("jamba", "JambaForCausalLM"),
("llama", "LlamaForCausalLM"),
("mamba", "MambaForCausalLM"),
("marian", "MarianForCausalLM"),
......@@ -851,6 +853,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
("gpt_neox", "GPTNeoXForSequenceClassification"),
("gptj", "GPTJForSequenceClassification"),
("ibert", "IBertForSequenceClassification"),
("jamba", "JambaForSequenceClassification"),
("layoutlm", "LayoutLMForSequenceClassification"),
("layoutlmv2", "LayoutLMv2ForSequenceClassification"),
("layoutlmv3", "LayoutLMv3ForSequenceClassification"),
......
......@@ -203,6 +203,13 @@ else:
("idefics", (None, "LlamaTokenizerFast" if is_tokenizers_available() else None)),
("idefics2", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
("instructblip", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
(
"jamba",
(
"LlamaTokenizer" if is_sentencepiece_available() else None,
"LlamaTokenizerFast" if is_tokenizers_available() else None,
),
),
("jukebox", ("JukeboxTokenizer", None)),
(
"kosmos-2",
......
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
_import_structure = {
"configuration_jamba": ["JambaConfig"],
}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_jamba"] = [
"JambaForCausalLM",
"JambaForSequenceClassification",
"JambaModel",
"JambaPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_jamba import JambaConfig
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_jamba import (
JambaForCausalLM,
JambaForSequenceClassification,
JambaModel,
JambaPreTrainedModel,
)
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
# coding=utf-8
# Copyright 2024 AI21 Labs Ltd. and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Jamba model configuration"""
import math
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
class JambaConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`JambaModel`]. It is used to instantiate a
Jamba model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the Jamba-v0.1 model.
[ai21labs/Jamba-v0.1](https://huggingface.co/ai21labs/Jamba-v0.1)
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 65536):
Vocabulary size of the Jamba model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`JambaModel`]
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the
model has an output word embedding layer.
hidden_size (`int`, *optional*, defaults to 4096):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 14336):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 32):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 32):
Number of attention heads for each attention layer in the Transformer encoder.
num_key_value_heads (`int`, *optional*, defaults to 8):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA); if
`num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by mean-pooling all the original heads within that group. For more details, check out [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
num_logits_to_keep (`int` or `None`, *optional*, defaults to 1):
Number of prompt logits to calculate during generation. If `None`, all logits will be calculated. If an
integer value, only last `num_logits_to_keep` logits will be calculated. Default is 1 because only the
logits of the last prompt token are needed for generation. For long sequences, the logits for the entire
sequence may use a lot of memory, so setting `num_logits_to_keep=1` will reduce the memory footprint
significantly.
output_router_logits (`bool`, *optional*, defaults to `False`):
Whether or not the router logits should be returned by the model. Enabling this will also
allow the model to output the auxiliary loss. See [here]() for more details
router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
The aux loss factor for the total loss.
pad_token_id (`int`, *optional*, defaults to 0):
The id of the padding token.
bos_token_id (`int`, *optional*, defaults to 1):
The id of the "beginning-of-sequence" token.
eos_token_id (`int`, *optional*, defaults to 2):
The id of the "end-of-sequence" token.
sliding_window (`int`, *optional*):
Sliding window attention window size. If not specified, will default to `None`.
max_position_embeddings (`int`, *optional*, defaults to 262144):
This value doesn't have any real effect. The maximum sequence length that this model is intended to be
used with. It can be used with longer sequences, but performance may degrade.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
num_experts_per_tok (`int`, *optional*, defaults to 2):
The number of experts to route per token; can also be interpreted as the `top-k` routing
parameter.
num_experts (`int`, *optional*, defaults to 16):
Number of experts per Sparse MLP layer.
expert_layer_period (`int`, *optional*, defaults to 2):
Once in this many layers, we will have an expert layer
expert_layer_offset (`int`, *optional*, defaults to 1):
The first layer index that contains an expert mlp layer
attn_layer_period (`int`, *optional*, defaults to 8):
Once in this many layers, we will have a vanilla attention layer
attn_layer_offset (`int`, *optional*, defaults to 4):
The first layer index that contains a vanilla attention layer
use_mamba_kernels (`bool`, *optional*, defaults to `True`):
Flag indicating whether or not to use the fast mamba kernels. These are available only if `mamba-ssm` and
`causal-conv1d` are installed, and the mamba modules are running on a CUDA device. Raises a `ValueError` if
`True` and the kernels are not available.
mamba_d_state (`int`, *optional*, defaults to 16):
The dimension of the mamba state space latents
mamba_d_conv (`int`, *optional*, defaults to 4):
The size of the mamba convolution kernel
mamba_expand (`int`, *optional*, defaults to 2):
Expanding factor (relative to hidden_size) used to determine the mamba intermediate size
mamba_dt_rank (`Union[int,str]`, *optional*, defaults to `"auto"`):
Rank of the mamba discretization projection matrix. `"auto"` means that it will default to `math.ceil(self.hidden_size / 16)`
mamba_conv_bias (`bool`, *optional*, defaults to `True`):
Flag indicating whether or not to use bias in the convolution layer of the mamba mixer block.
mamba_proj_bias (`bool`, *optional*, defaults to `False`):
Flag indicating whether or not to use bias in the input and output projections (["in_proj", "out_proj"]) of the mamba mixer block
"""
model_type = "jamba"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=65536,
tie_word_embeddings=False,
hidden_size=4096,
intermediate_size=14336,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=8,
hidden_act="silu",
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
num_logits_to_keep=1,
output_router_logits=False,
router_aux_loss_coef=0.001,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
sliding_window=None,
max_position_embeddings=262144,
attention_dropout=0.0,
num_experts_per_tok=2,
num_experts=16,
expert_layer_period=2,
expert_layer_offset=1,
attn_layer_period=8,
attn_layer_offset=4,
use_mamba_kernels=True,
mamba_d_state=16,
mamba_d_conv=4,
mamba_expand=2,
mamba_dt_rank="auto",
mamba_conv_bias=True,
mamba_proj_bias=False,
**kwargs,
):
self.vocab_size = vocab_size
self.tie_word_embeddings = tie_word_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.sliding_window = sliding_window
self.max_position_embeddings = max_position_embeddings
self.attention_dropout = attention_dropout
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.num_logits_to_keep = num_logits_to_keep
self.output_router_logits = output_router_logits
self.router_aux_loss_coef = router_aux_loss_coef
self.num_experts_per_tok = num_experts_per_tok
self.num_experts = num_experts
self.expert_layer_period = expert_layer_period
self.expert_layer_offset = expert_layer_offset
self.attn_layer_period = attn_layer_period
self.attn_layer_offset = attn_layer_offset
self.use_mamba_kernels = use_mamba_kernels
self.mamba_d_state = mamba_d_state
self.mamba_d_conv = mamba_d_conv
self.mamba_expand = mamba_expand
self.mamba_dt_rank = math.ceil(self.hidden_size / 16) if mamba_dt_rank == "auto" else mamba_dt_rank
self.mamba_conv_bias = mamba_conv_bias
self.mamba_proj_bias = mamba_proj_bias
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
@property
def layers_block_type(self):
return [
"attention" if i % self.attn_layer_period == self.attn_layer_offset else "mamba"
for i in range(self.num_hidden_layers)
]
@property
def layers_num_experts(self):
return [
self.num_experts if i % self.expert_layer_period == self.expert_layer_offset else 1
for i in range(self.num_hidden_layers)
]
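
With the default values above (`num_hidden_layers=32`, `attn_layer_period=8`, `attn_layer_offset=4`, `expert_layer_period=2`, `expert_layer_offset=1`, `num_experts=16`), the two properties yield the layer pattern below; a quick standalone check that mirrors the list comprehensions without instantiating the config:

```python
num_hidden_layers = 32
attn_layer_period, attn_layer_offset = 8, 4
expert_layer_period, expert_layer_offset = 2, 1
num_experts = 16

layers_block_type = [
    "attention" if i % attn_layer_period == attn_layer_offset else "mamba"
    for i in range(num_hidden_layers)
]
layers_num_experts = [
    num_experts if i % expert_layer_period == expert_layer_offset else 1
    for i in range(num_hidden_layers)
]

print([i for i, t in enumerate(layers_block_type) if t == "attention"])  # [4, 12, 20, 28]
print(layers_block_type.count("mamba"))                                  # 28
print(layers_num_experts.count(num_experts))                             # 16 (every odd layer is an expert layer)
```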
This diff is collapsed.
......@@ -4526,6 +4526,34 @@ class InstructBlipVisionModel(metaclass=DummyObject):
requires_backends(self, ["torch"])
class JambaForCausalLM(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class JambaForSequenceClassification(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class JambaModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class JambaPreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
JUKEBOX_PRETRAINED_MODEL_ARCHIVE_LIST = None
......
......@@ -1050,7 +1050,7 @@ class GenerationTesterMixin:
for model_class in self.all_generative_model_classes:
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer", "speech2text"]):
self.skipTest("Won't fix: old model with different cache format")
if any(model_name in model_class.__name__.lower() for model_name in ["gptbigcode"]):
if any(model_name in model_class.__name__.lower() for model_name in ["gptbigcode", "jamba"]):
self.skipTest("TODO: fix me")
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1)
......@@ -1098,6 +1098,7 @@ class GenerationTesterMixin:
"transo_xl",
"xlnet",
"cpm",
"jamba",
]
):
self.skipTest("May fix in the future: need model-specific fixes")
......@@ -1735,11 +1736,12 @@ class GenerationTesterMixin:
use_cache=use_cache,
)
# Past Key Value States -- two notes here:
# Past Key Value States -- a few notes here:
# 1. Its inner sequence length is with respect to the inputs of the latest forward pass, hence the "-1"
# 2. Some old models still return `output.past_key_values` even without `use_cache=True`
# 3. TODO (joao): A few models have different formats, skipping those until the cache refactor is complete
models_without_standard_cache = ("bloom", "ctrl", "fsmt", "gptbigcode", "mega", "reformer")
# 3. TODO (joao): A few models have different formats/types, skipping those until the cache refactor is
# complete
models_without_standard_cache = ("bloom", "ctrl", "fsmt", "gptbigcode", "mega", "reformer", "jamba")
has_standard_cache = not any(
model_name in config.__class__.__name__.lower() for model_name in models_without_standard_cache
)
......
This diff is collapsed.
......@@ -32,6 +32,15 @@ transformers = direct_transformers_import(PATH_TO_TRANSFORMERS)
CONFIG_MAPPING = transformers.models.auto.configuration_auto.CONFIG_MAPPING
SPECIAL_CASES_TO_ALLOW = {
# 'max_position_embeddings' is not used in modeling file, but needed for eval frameworks like Huggingface's lighteval (https://github.com/huggingface/lighteval/blob/af24080ea4f16eaf1683e353042a2dfc9099f038/src/lighteval/models/base_model.py#L264).
# periods and offsets are not used in the modeling file, but are used in the configuration file to define `layers_block_type` and `layers_num_experts`.
"JambaConfig": [
"max_position_embeddings",
"attn_layer_offset",
"attn_layer_period",
"expert_layer_offset",
"expert_layer_period",
],
# used to compute the property `self.chunk_length`
"EncodecConfig": ["overlap"],
# used to compute the property `self.layers_block_type`
......
......@@ -631,6 +631,8 @@ src/transformers/models/instructblip/configuration_instructblip.py
src/transformers/models/instructblip/convert_instructblip_original_to_pytorch.py
src/transformers/models/instructblip/modeling_instructblip.py
src/transformers/models/instructblip/processing_instructblip.py
src/transformers/models/jamba/configuration_jamba.py
src/transformers/models/jamba/modeling_jamba.py
src/transformers/models/jukebox/configuration_jukebox.py
src/transformers/models/jukebox/convert_jukebox.py
src/transformers/models/jukebox/modeling_jukebox.py
......