Unverified Commit 80b90e7b authored by Pablo Montalvo, committed by GitHub

Add codestral mamba2 (#32080)

* add new model like

* draft cuda forward - mismatched keys (sharding on conv1)

* match keys successfully

* fix split

* get generation/forward running (wrong gens, norm?)

* :update

* some refactoring

* fixes

* works up until copy to cache

* fix

* update

* NON WORKING VERSION

* version that work?

* nit

* fix config

* fix conversion script

* working cuda forward

* nit

* update

* simplification

* make mamba slow simple work

* no einops

* todo

* fix style

* no einops

* update fix no einsum

* nit

* remove einops

* bug: scan_output differs strongly

* add rms norm option

* fix fast + slow generation with and w/o cache 



* draft integration tests

* remove a big chunk of the einsum

* fix slow, fast generations, without any einsum

* fix copies

* fix structure

* fix up modeling and tests

* fix tests

* clamping is indeed worse

* recover mamba2 cache test

* fix copies

* no cache position (yet)

* fix tf tests

* fix matmul for generate

* fixup

* skip cache tests for now

* [run-slow]mamba2

* tune out hidden states for padding

* test batched generation

* propagate attention mask changes

* fix past length

* fix integration test

* style

* address comments

* update readme

* add mamba2 version check

* fix tests

* [run-slow]mamba2

* skip edge tests

* [run-slow]mamba2

* last fixup

* [run-slow]mamba2

* update README

---------
Co-authored-by: Arthur Zucker <arthur.zucker@gmail.com>
parent 3d8bd119
@@ -438,6 +438,8 @@
      title: MADLAD-400
    - local: model_doc/mamba
      title: Mamba
    - local: model_doc/mamba2
      title: mamba2
    - local: model_doc/marian
      title: MarianMT
    - local: model_doc/markuplm
...
@@ -194,6 +194,7 @@ Flax), PyTorch, and/or TensorFlow.
| [M2M100](model_doc/m2m_100) | ✅ | ❌ | ❌ |
| [MADLAD-400](model_doc/madlad-400) | ✅ | ✅ | ✅ |
| [Mamba](model_doc/mamba) | ✅ | ❌ | ❌ |
| [mamba2](model_doc/mamba2) | ✅ | ❌ | ❌ |
| [Marian](model_doc/marian) | ✅ | ✅ | ✅ |
| [MarkupLM](model_doc/markuplm) | ✅ | ❌ | ❌ |
| [Mask2Former](model_doc/mask2former) | ✅ | ❌ | ❌ |
...
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# Mamba 2
## Overview
The Mamba2 model was proposed in [Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality](https://arxiv.org/abs/2405.21060) by Tri Dao and Albert Gu. It is a State Space Model similar to Mamba 1, with better performance and a simplified architecture.
The abstract from the paper is the following:
*While Transformers have been the main architecture behind deep learning's success in language modeling, state-space models (SSMs) such as Mamba have recently been shown to match or outperform Transformers at small to medium scale. We show that these families of models are actually quite closely related, and develop a rich framework of theoretical connections between SSMs and variants of attention, connected through various decompositions of a well-studied class of structured semiseparable matrices. Our state space duality (SSD) framework allows us to design a new architecture (Mamba-2) whose core layer is a refinement of Mamba's selective SSM that is 2-8X faster, while continuing to be competitive with Transformers on language modeling.*
Tips:
This version should support all implementations of Mamba 2, in particular [Mamba-2 codestral](https://huggingface.co/mistralai/Mamba-Codestral-7B-v0.1) from Mistral AI. Mamba-2 codestral was released with 8 `groups`, which can be thought of intuitively as analogous to the number of KV heads in an attention-based model.
This model has two different forward passes, `torch_forward` and `cuda_kernels_forward`. The latter uses the original CUDA kernels if they are found in your environment, and it is slower on the prefill, i.e. it requires a "warmup run" due to high CPU overhead, see [here](https://github.com/state-spaces/mamba/issues/389#issuecomment-2171755306) and [also here](https://github.com/state-spaces/mamba/issues/355#issuecomment-2147597457). Without compilation, the `torch_forward` implementation is faster by a factor of 3 to 4. Further, there are no positional embeddings in this model, but there is an `attention_mask` and specific logic to mask out hidden states in two places in the case of batched generation, see [here](https://github.com/state-spaces/mamba/issues/66#issuecomment-1863563829) as well. Because of this, in addition to the reimplementation of the mamba2 kernels, batched generation and cached generation are expected to show slight discrepancies. Moreover, the results given by the CUDA kernels and the torch forward are expected to differ slightly: the SSM algorithm relies heavily on tensor contractions, which have matmul equivalents but with a slightly different order of operations, and the difference grows at lower precisions.
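To check which path will be taken, a rough probe is to try the optional imports yourself. This is a minimal sketch that simply mirrors the optional dependencies the modeling code looks for; it is not an official API:
```python
# a minimal sketch: these imports mirror the optional dependencies checked by the modeling code
try:
    from mamba_ssm.ops.triton.selective_state_update import selective_state_update  # noqa: F401
    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update  # noqa: F401

    print("Fast path available: `cuda_kernels_forward` will be used on CUDA devices")
except ImportError:
    print("Falling back to the pure PyTorch `torch_forward` implementation")
```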
Note also that the hidden states corresponding to padding tokens are shut off in two places, and this has mostly been tested with left-padding. Right-padding will propagate noise down the sequence and is not guaranteed to yield satisfactory results; setting `tokenizer.padding_side = "left"` ensures you are using the correct padding side.
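For batched generation with left padding, a minimal sketch (the prompts are illustrative; the checkpoint and revision are the same as in the example below):
```python
from transformers import AutoTokenizer, Mamba2ForCausalLM

model_id = 'mistralai/Mamba-Codestral-7B-v0.1'
tokenizer = AutoTokenizer.from_pretrained(model_id, revision='refs/pr/9', from_slow=True, legacy=False)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"  # left padding is the tested configuration

model = Mamba2ForCausalLM.from_pretrained(model_id, revision='refs/pr/9')

prompts = ["Hey how are you doing?", "Why is the sky blue?"]
inputs = tokenizer(prompts, return_tensors="pt", padding=True)
out = model.generate(**inputs, max_new_tokens=10)
print(tokenizer.batch_decode(out, skip_special_tokens=True))
```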
This model was contributed by [Molbap](https://huggingface.co/Molbap), with tremendous help from [Anton Vlasjuk](https://github.com/vasqu).
The original code can be found [here](https://github.com/state-spaces/mamba).
## Usage
### A simple generation example:
```python
from transformers import Mamba2Config, Mamba2ForCausalLM, AutoTokenizer
import torch

model_id = 'mistralai/Mamba-Codestral-7B-v0.1'
tokenizer = AutoTokenizer.from_pretrained(model_id, revision='refs/pr/9', from_slow=True, legacy=False)
model = Mamba2ForCausalLM.from_pretrained(model_id, revision='refs/pr/9')
input_ids = tokenizer("Hey how are you doing?", return_tensors="pt")["input_ids"]

out = model.generate(input_ids, max_new_tokens=10)
print(tokenizer.batch_decode(out))
```
Here is a draft script for fine-tuning:
```python
from datasets import load_dataset
from trl import SFTTrainer
from peft import LoraConfig
from transformers import AutoTokenizer, Mamba2ForCausalLM, TrainingArguments

model_id = 'mistralai/Mamba-Codestral-7B-v0.1'
tokenizer = AutoTokenizer.from_pretrained(model_id, revision='refs/pr/9', from_slow=True, legacy=False)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"  # enforce left padding

model = Mamba2ForCausalLM.from_pretrained(model_id, revision='refs/pr/9')
dataset = load_dataset("Abirate/english_quotes", split="train")
# Without CUDA kernels, a batch size of 2 occupies one 80GB device,
# but precision can be reduced.
# Experiments and trials welcome!
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=2,
    logging_dir='./logs',
    logging_steps=10,
    learning_rate=2e-3,
)
lora_config = LoraConfig(
    r=8,
    target_modules=["embeddings", "in_proj", "out_proj"],
    task_type="CAUSAL_LM",
    bias="none",
)
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    peft_config=lora_config,
    train_dataset=dataset,
    dataset_text_field="quote",
)
trainer.train()
```
## Mamba2Config
[[autodoc]] Mamba2Config
## Mamba2Model
[[autodoc]] Mamba2Model
- forward
## Mamba2ForCausalLM
[[autodoc]] Mamba2ForCausalLM
- forward
@@ -544,6 +544,7 @@ _import_structure = {
    ],
    "models.m2m_100": ["M2M100Config"],
    "models.mamba": ["MambaConfig"],
    "models.mamba2": ["Mamba2Config"],
    "models.marian": ["MarianConfig"],
    "models.markuplm": [
        "MarkupLMConfig",
@@ -2550,6 +2551,13 @@ else:
            "MambaPreTrainedModel",
        ]
    )
    _import_structure["models.mamba2"].extend(
        [
            "Mamba2ForCausalLM",
            "Mamba2Model",
            "Mamba2PreTrainedModel",
        ]
    )
    _import_structure["models.marian"].extend(["MarianForCausalLM", "MarianModel", "MarianMTModel"])
    _import_structure["models.markuplm"].extend(
        [
@@ -5240,6 +5248,7 @@ if TYPE_CHECKING:
    )
    from .models.m2m_100 import M2M100Config
    from .models.mamba import MambaConfig
    from .models.mamba2 import Mamba2Config
    from .models.marian import MarianConfig
    from .models.markuplm import (
        MarkupLMConfig,
@@ -7046,6 +7055,11 @@ if TYPE_CHECKING:
        MambaModel,
        MambaPreTrainedModel,
    )
    from .models.mamba2 import (
        Mamba2ForCausalLM,
        Mamba2Model,
        Mamba2PreTrainedModel,
    )
    from .models.marian import MarianForCausalLM, MarianModel, MarianMTModel
    from .models.markuplm import (
        MarkupLMForQuestionAnswering,
...
@@ -135,6 +135,7 @@ from . import (
    lxmert,
    m2m_100,
    mamba,
    mamba2,
    marian,
    markuplm,
    mask2former,
...
@@ -152,6 +152,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
        ("lxmert", "LxmertConfig"),
        ("m2m_100", "M2M100Config"),
        ("mamba", "MambaConfig"),
        ("mamba2", "Mamba2Config"),
        ("marian", "MarianConfig"),
        ("markuplm", "MarkupLMConfig"),
        ("mask2former", "Mask2FormerConfig"),
@@ -440,6 +441,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
        ("m2m_100", "M2M100"),
        ("madlad-400", "MADLAD-400"),
        ("mamba", "Mamba"),
        ("mamba2", "mamba2"),
        ("marian", "Marian"),
        ("markuplm", "MarkupLM"),
        ("mask2former", "Mask2Former"),
...
@@ -144,6 +144,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
        ("lxmert", "LxmertModel"),
        ("m2m_100", "M2M100Model"),
        ("mamba", "MambaModel"),
        ("mamba2", "Mamba2Model"),
        ("marian", "MarianModel"),
        ("markuplm", "MarkupLMModel"),
        ("mask2former", "Mask2FormerModel"),
@@ -310,6 +311,7 @@ MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
        ("luke", "LukeForMaskedLM"),
        ("lxmert", "LxmertForPreTraining"),
        ("mamba", "MambaForCausalLM"),
        ("mamba2", "Mamba2ForCausalLM"),
        ("mega", "MegaForMaskedLM"),
        ("megatron-bert", "MegatronBertForPreTraining"),
        ("mobilebert", "MobileBertForPreTraining"),
@@ -394,6 +396,7 @@ MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
        ("luke", "LukeForMaskedLM"),
        ("m2m_100", "M2M100ForConditionalGeneration"),
        ("mamba", "MambaForCausalLM"),
        ("mamba2", "Mamba2ForCausalLM"),
        ("marian", "MarianMTModel"),
        ("mega", "MegaForMaskedLM"),
        ("megatron-bert", "MegatronBertForCausalLM"),
@@ -472,6 +475,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
        ("jetmoe", "JetMoeForCausalLM"),
        ("llama", "LlamaForCausalLM"),
        ("mamba", "MambaForCausalLM"),
        ("mamba2", "Mamba2ForCausalLM"),
        ("marian", "MarianForCausalLM"),
        ("mbart", "MBartForCausalLM"),
        ("mega", "MegaForCausalLM"),
...
@@ -270,6 +270,7 @@ else:
        ("lxmert", ("LxmertTokenizer", "LxmertTokenizerFast" if is_tokenizers_available() else None)),
        ("m2m_100", ("M2M100Tokenizer" if is_sentencepiece_available() else None, None)),
        ("mamba", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
        ("mamba2", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
        ("marian", ("MarianTokenizer" if is_sentencepiece_available() else None, None)),
        (
            "mbart",
...
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_torch_available,
)
_import_structure = {
"configuration_mamba2": ["Mamba2Config", "Mamba2OnnxConfig"],
}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_mamba2"] = [
"Mamba2ForCausalLM",
"Mamba2Model",
"Mamba2PreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_mamba2 import Mamba2Config
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_mamba2 import (
Mamba2ForCausalLM,
Mamba2Model,
Mamba2PreTrainedModel,
)
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MAMBA2 configuration"""
import math
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
class Mamba2Config(PretrainedConfig):
"""
This is the configuration class to store the configuration of a [`Mamba2Model`]. It is used to instantiate a MAMBA2
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the MAMBA2
[state-spaces/mamba2-2.8b](https://huggingface.co/state-spaces/mamba2-2.8b) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
num_heads (`int`, *optional*, defaults to 128):
Number of heads for the evolution matrices of mamba 2.
head_dim (`int`, *optional*, defaults to 64):
Dimension of each head.
vocab_size (`int`, *optional*, defaults to 32768):
Vocabulary size of the MAMBA2 model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`Mamba2Model`].
hidden_size (`int`, *optional*, defaults to 4096):
Dimensionality of the embeddings and hidden states.
state_size (`int`, *optional*, defaults to 128): shape of the state space latents.
num_hidden_layers (`int`, *optional*, defaults to 64):
Number of hidden layers in the model.
layer_norm_epsilon (`float`, *optional*, defaults to 1e-05):
The epsilon to use in the layer normalization layers.
pad_token_id (`int`, *optional*, defaults to 1):
Padding token id.
bos_token_id (`int`, *optional*, defaults to 0):
The id of the beginning of sentence token in the vocabulary.
eos_token_id (`int`, *optional*, defaults to 2):
The id of the end of sentence token in the vocabulary.
expand (`int`, *optional*, defaults to 2): Expanding factor used to determine the intermediate size.
conv_kernel (`int`, *optional*, defaults to 4): Size of the convolution kernel.
n_groups (`int`, *optional*, defaults to 8):
Number of groups for the evolution matrices of mamba 2.
use_bias (`bool`, *optional*, defaults to `False`):
Whether or not to use bias in ["in_proj", "out_proj"] of the mixer block
use_conv_bias (`bool`, *optional*, defaults to `True`):
Whether or not to use bias in the convolution layer of the mixer block.
hidden_act (`str`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
initializer_range (`float`, *optional*, defaults to 0.1):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
residual_in_fp32 (`bool`, *optional*, defaults to `True`):
Whether or not residuals should be in `float32`. If set to `False` residuals will keep the same `dtype` as the rest of the model
time_step_rank (`Union[int,str]`, *optional*, defaults to `"auto"`):
Rank of the discretization projection matrix. `"auto"` means that it will default to `math.ceil(self.hidden_size / 16)`
time_step_min (`float`, *optional*, defaults to 0.001):
Minimum `time_step` used to bound `dt_proj.bias`.
time_step_max (`float`, *optional*, defaults to 0.1):
Maximum `time_step` used to bound `dt_proj.bias`.
time_step_floor (`float`, *optional*, defaults to 0.0001):
Minimum clamping value of the `dt_proj.bias` layer initialization.
time_step_limit (`tuple`, *optional*, defaults to `(0.0, inf)`):
Accepted range of time step values.
rescale_prenorm_residual (`bool`, *optional*, defaults to `False`):
Whether or not to rescale `out_proj` weights when initializing.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the cache should be used.
norm_before_gate (`bool`, *optional*, defaults to `True`):
Option of the CUDA kernels: whether to normalize before the gate or not.
rms_norm (`bool`, *optional*, defaults to `True`):
Whether to use RMS norm or not.
chunk_size (`int`, *optional*, defaults to 256):
Size of the chunks that will comprise the sequence.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether to tie word embeddings or not.
Example:
```python
>>> from transformers import Mamba2Config, Mamba2Model
>>> # Initializing a Mamba2 configuration
>>> configuration = Mamba2Config()
>>> # Initializing a model (with random weights) from the configuration
>>> model = Mamba2Model(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "mamba2"
def __init__(
self,
num_heads=128,
head_dim=64,
vocab_size=32768,
hidden_size=4096,
state_size=128,
num_hidden_layers=64,
layer_norm_epsilon=1e-5,
pad_token_id=1,
bos_token_id=0,
eos_token_id=2,
expand=2,
conv_kernel=4,
n_groups=8,
use_bias=False,
use_conv_bias=True,
hidden_act="silu",
initializer_range=0.1,
residual_in_fp32=True,
time_step_rank="auto",
time_step_min=0.001,
time_step_max=0.1,
time_step_floor=1e-4,
time_step_limit=(0.0, float("inf")),
rescale_prenorm_residual=False,
use_cache=True,
norm_before_gate=True,
rms_norm=True,
chunk_size=256,
tie_word_embeddings=False,
**kwargs,
):
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.state_size = state_size
self.num_hidden_layers = num_hidden_layers
self.layer_norm_epsilon = layer_norm_epsilon
self.conv_kernel = conv_kernel
self.expand = expand
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
self.pad_token_id = pad_token_id
self.use_bias = use_bias
self.use_conv_bias = use_conv_bias
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.time_step_rank = math.ceil(self.hidden_size / 16) if time_step_rank == "auto" else time_step_rank
self.time_step_min = time_step_min
self.time_step_max = time_step_max
self.time_step_floor = time_step_floor
self.rescale_prenorm_residual = rescale_prenorm_residual
self.residual_in_fp32 = residual_in_fp32
self.use_cache = use_cache
self.n_groups = n_groups
self.num_heads = num_heads
self.head_dim = head_dim
self.norm_before_gate = norm_before_gate
self.rms_norm = rms_norm
self.chunk_size = chunk_size
self.time_step_limit = time_step_limit
self.tie_word_embeddings = tie_word_embeddings
super().__init__(
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
pad_token_id=pad_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
# coding=utf-8
# Copyright 2024 state-spaces/mamba2 org and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This script can be used to convert checkpoints provided in the `mamba2_ssm` library into the format provided in HuggingFace `transformers`. It depends on the `mamba2_ssm` package to be installed."""
import argparse
import torch
from safetensors import safe_open
from transformers import LlamaTokenizerFast, Mamba2Config, Mamba2ForCausalLM
def convert_mamba2_checkpoint_file_to_huggingface_model_file(
mamba2_checkpoint_path: str, tokenizer_model_path: str, output_dir: str
) -> None:
hf_config = Mamba2Config()
hf_model = Mamba2ForCausalLM(hf_config)
# Load weights and config from paths
original_state_dict = {}
with safe_open(mamba2_checkpoint_path, framework="pt") as f:
for k in f.keys():
newk = k.removeprefix("model.")
original_state_dict[newk] = f.get_tensor(k).clone()
hf_model.load_state_dict(original_state_dict)
# Save new model to pytorch_dump_path
hf_model.to(torch.bfloat16).save_pretrained(output_dir)
tokenizer_class = LlamaTokenizerFast
tokenizer = tokenizer_class(tokenizer_model_path, legacy=False, from_slow=True)
tokenizer.save_pretrained(output_dir)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"-i",
"--mamba2_checkpoint_file",
type=str,
required=True,
help="Path to a `pytorch_model.bin` mamba2_ssm checkpoint file to be converted.",
)
parser.add_argument(
"-c",
"--tokenizer_model_path",
type=str,
required=True,
help="Path to a `config.json` file corresponding to a Mamba2Config of the original mamba2_ssm model.",
)
parser.add_argument(
"-o", "--output_dir", type=str, required=True, help="Path to directory to save the converted output model to."
)
args = parser.parse_args()
convert_mamba2_checkpoint_file_to_huggingface_model_file(
args.mamba2_checkpoint_file, args.tokenizer_model_path, args.output_dir
)
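# Example invocation (illustrative; the script name and file paths are assumptions, not part of this commit):
#   python convert_mamba2_ssm_checkpoint_to_pytorch.py \
#       -i mamba2_checkpoint.safetensors -c tokenizer.model -o ./mamba2-hf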
# coding=utf-8
# Copyright 2024 state-spaces/mamba2 org and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch MAMBA2 model."""
import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...modeling_utils import PreTrainedModel
from ...utils import (
ModelOutput,
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
)
from ...utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available
from .configuration_mamba2 import Mamba2Config
logger = logging.get_logger(__name__)
if is_mamba_2_ssm_available():
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
else:
selective_state_update = None
if is_causal_conv1d_available():
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
else:
causal_conv1d_update, causal_conv1d_fn = None, None
is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update))
_CHECKPOINT_FOR_DOC = "mistralai/mamba-codestral-7B-v0.1"
_CONFIG_FOR_DOC = "Mamba2Config"
# Helper methods for segment sum computation
def pad_tensor_by_size(input_tensor: torch.Tensor, pad_size: int):
"""
Pad the input tensor with `pad_size` on the seq_len dim (dim=1).
Assumes that we only have tensors of rank 3 or 4.
"""
pad_shape = (0, 0, 0, 0, 0, pad_size, 0, 0) if len(input_tensor.shape) == 4 else (0, 0, 0, pad_size, 0, 0)
return torch.nn.functional.pad(input_tensor, pad_shape, mode="constant", value=0)
def reshape_into_chunks(input_tensor, pad_size, chunk_size):
"""
Pad `input_tensor` with `pad_size` on the seq_len dim (dim=1) and
simultaneously split it into chunk sequences.
Assumes that we only have tensors of rank 3 or 4.
"""
# [bsz, seq_len, ...] -> [bsz, seq_len multiple of chunk_size, ...]
input_tensor = pad_tensor_by_size(input_tensor, pad_size)
if len(input_tensor.shape) == 3:
# [bsz, seq_len multiple of chunk_size, num_heads] -> [bsz, -1, chunk_size, num_heads]
return input_tensor.reshape(input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2])
else:
# [bsz, seq_len multiple of chunk_size, num_heads, head_dim or state_size] -> [bsz, -1, chunk_size, num_heads, head_dim or state_size]
return input_tensor.reshape(
input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2], input_tensor.shape[3]
)
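# Illustrative shapes (assuming batch=2, seq_len=100, num_heads=8, head_dim=64, chunk_size=64):
# pad_size = 64 - (100 % 64) = 28, so a [2, 100, 8, 64] tensor is padded to [2, 128, 8, 64] and
# reshaped to [2, 2, 64, 8, 64], i.e. (batch, num_chunks, chunk_size, num_heads, head_dim).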
def segment_sum(input_tensor):
"""
More stable segment sum calculation. Uses cumulative sums and masking instead of direct subtractions.
"""
chunk_size = input_tensor.size(-1)
# 1. expand input tensor to have an additional dimension and repeat along that dimension
# [..., chunk_size] -> [..., chunk_size, chunk_size]
input_tensor = input_tensor[..., None].expand(*input_tensor.size(), chunk_size)
# 2. create a strictly lower triangular mask (diagonal excluded) so that elements on and above the diagonal are zeroed out
mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=-1)
input_tensor = input_tensor.masked_fill(~mask, 0)
# 3. compute actual cumsum
tensor_segsum = torch.cumsum(input_tensor, dim=-2)
# 4. apply mask to keep only the lower triangular part of the cumulative sum result (incl diagonal this time)
mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=0)
tensor_segsum = tensor_segsum.masked_fill(~mask, -torch.inf)
return tensor_segsum
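# Illustrative semantics: for a 1D input [a1, a2, a3], segment_sum returns a matrix S with
# S[i, j] = a_{j+1} + ... + a_i for i >= j (so the diagonal is 0) and -inf above the diagonal,
# so that exp(S) is the causal, lower-triangular decay matrix used by the chunked scan below.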
class Mamba2Cache:
"""
Arguments:
config: Mamba2Config
batch_size: int
dtype: torch.dtype
device: torch.device
Attributes:
seqlen_offset: int
dtype: torch.dtype
conv_states: Dict[int, torch.Tensor] # layer_idx -> [batch_size, intermediate_size + 2 * n_groups * state_size, conv_kernel_size]
ssm_states: Dict[int, torch.Tensor] # layer_idx -> [batch_size, num_heads, head_dim, state_size]
"""
def __init__(
self, config: Mamba2Config, batch_size: int, dtype: torch.dtype = torch.float16, device: Optional[str] = None
):
self.seqlen_offset = 0
self.dtype = dtype
self.conv_kernel_size = config.conv_kernel
self.intermediate_size = int(config.expand * config.hidden_size)
self.conv_states = {
i: torch.zeros(
batch_size,
self.intermediate_size + 2 * config.n_groups * config.state_size,
self.conv_kernel_size,
device=device,
dtype=dtype,
)
for i in range(config.num_hidden_layers)
}
self.ssm_states = {
i: torch.zeros(
batch_size, config.num_heads, config.head_dim, config.state_size, device=device, dtype=dtype
)
for i in range(config.num_hidden_layers)
}
self.activation = config.hidden_act
self.act = ACT2FN[config.hidden_act]
def update_conv_state(
self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor
) -> torch.Tensor:
conv_state = self.conv_states[layer_idx]
cache_position = cache_position.clamp(0, self.conv_kernel_size - 1)
conv_state = conv_state.roll(shifts=-1, dims=-1)
conv_state[:, :, cache_position] = new_conv_state.to(conv_state.device)
self.conv_states[layer_idx].zero_()
self.conv_states[layer_idx] += conv_state
return self.conv_states[layer_idx]
def reset(self):
# conv_states / ssm_states are dicts of per-layer tensors, so zero them layer by layer
for layer_idx in self.conv_states:
    self.conv_states[layer_idx].zero_()
    self.ssm_states[layer_idx].zero_()
class MambaRMSNormGated(torch.nn.Module):
def __init__(self, hidden_size, eps=1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states, gate=None):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
if gate is not None:
hidden_states = hidden_states * nn.functional.silu(gate.to(torch.float32))
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
class Mamba2Mixer(nn.Module):
"""
Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`.
A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
and is why Mamba is called **selective** state spaces)
"""
def __init__(self, config: Mamba2Config, layer_idx: int):
super().__init__()
self.num_heads = config.num_heads
self.hidden_size = config.hidden_size
self.ssm_state_size = config.state_size
self.conv_kernel_size = config.conv_kernel
self.intermediate_size = int(config.expand * self.hidden_size)
self.time_step_rank = int(config.time_step_rank)
self.layer_idx = layer_idx
self.use_conv_bias = config.use_conv_bias
self.activation = config.hidden_act
self.act = ACT2FN[config.hidden_act]
self.norm_before_gate = config.norm_before_gate
self.layer_norm_epsilon = config.layer_norm_epsilon
self.rms_norm = config.rms_norm
self.n_groups = config.n_groups
self.head_dim = config.head_dim
self.chunk_size = config.chunk_size
self.time_step_limit = config.time_step_limit
self.time_step_min = config.time_step_min
self.time_step_max = config.time_step_max
self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size
self.conv1d = nn.Conv1d(
in_channels=self.conv_dim,
out_channels=self.conv_dim,
bias=config.use_conv_bias,
kernel_size=config.conv_kernel,
groups=self.conv_dim,
padding=config.conv_kernel - 1,
)
# projection of the input hidden states
projection_size = self.intermediate_size + self.conv_dim + self.num_heads
self.in_proj = nn.Linear(
self.hidden_size,
projection_size,
bias=config.use_bias,
)
# selective projection used to make dt, B and C input dependant
# time step projection (discretization)
# instantiate once and copy inv_dt in init_weights of PretrainedModel
self.dt_bias = nn.Parameter(torch.ones(self.num_heads))
# S4D real initialization. These are not discretized!
# The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
A = torch.arange(1, self.num_heads + 1)
self.A_log = nn.Parameter(torch.log(A))
self.A_log._no_weight_decay = True
self.norm = MambaRMSNormGated(self.intermediate_size, eps=self.layer_norm_epsilon)
self.D = nn.Parameter(torch.ones(self.num_heads))
self.D._no_weight_decay = True
self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
self.use_bias = config.use_bias
if not is_fast_path_available:
logger.warning_once(
"The fast path is not available because on of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`"
" is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and"
" https://github.com/Dao-AILab/causal-conv1d"
)
def cuda_kernels_forward(
self,
hidden_states: torch.Tensor,
cache_params: Optional[Mamba2Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
):
# set up dimensions for reshapes later
batch_size, seq_len, _ = hidden_states.shape
groups_time_state_size = self.n_groups * self.ssm_state_size
d_to_remove = 2 * self.intermediate_size + 2 * self.n_groups * self.ssm_state_size + self.num_heads
# getting projected states from cache if it exists
if cache_params is not None and cache_params.seqlen_offset > 0:
in_projected_states = self.in_proj(hidden_states.squeeze(1)) # (B 2D)
d_mlp = (in_projected_states.shape[-1] - d_to_remove) // 2
split_projection_dim = [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads]
_, _, gate, hidden_states_B_C, dt = torch.split(in_projected_states, split_projection_dim, dim=-1)
hidden_states_B_C = causal_conv1d_update(
hidden_states_B_C,
cache_params.conv_states[self.layer_idx],
self.conv1d.weight.squeeze(1),
self.conv1d.bias,
self.activation,
)
hidden_states, B, C = torch.split(
hidden_states_B_C,
[self.intermediate_size, groups_time_state_size, groups_time_state_size],
dim=-1,
)
A = -torch.exp(self.A_log.float()) # (nheads,)
A = A[:, None, ...][:, :, None].expand(-1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
dt = dt[:, :, None].expand(-1, -1, self.head_dim)
dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim)
D = self.D[:, None, ...].expand(-1, self.head_dim)
B = B.view(batch_size, self.n_groups, B.shape[1] // self.n_groups)
C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups)
hidden_states_reshaped = hidden_states.view(batch_size, self.num_heads, self.head_dim)
hidden_states = selective_state_update(
cache_params.ssm_states[self.layer_idx],
hidden_states_reshaped,
dt,
A,
B,
C,
D,
z=None,
dt_bias=dt_bias,
dt_softplus=True,
)
hidden_states = hidden_states.view(batch_size, self.num_heads * self.head_dim)
hidden_states = self.norm(hidden_states, gate)
out = self.out_proj(hidden_states)[:, None, ...]
# if no cache is found, calling the kernel
else:
if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
# tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
dtype = hidden_states.dtype
hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
# 1. Gated MLP's linear projection
projected_states = self.in_proj(hidden_states)
A = -torch.exp(self.A_log.float()) # (num_heads) or (intermediate_size, state_size)
dt_limit_kwargs = {} if self.time_step_limit == (0.0, float("inf")) else {"dt_limit": self.time_step_limit}
if self.training and cache_params is None:
out, ssm_state = mamba_split_conv1d_scan_combined(
projected_states,
self.conv1d.weight.squeeze(1),
self.conv1d.bias,
self.dt_bias,
A,
D=self.D,
chunk_size=self.chunk_size,
seq_idx=None, # was seq_idx
activation=self.activation,
rmsnorm_weight=self.norm.weight,
rmsnorm_eps=self.norm.variance_epsilon,
outproj_weight=self.out_proj.weight,
outproj_bias=self.out_proj.bias,
headdim=self.head_dim,
ngroups=self.n_groups,
norm_before_gate=self.norm_before_gate,
return_final_states=True,
**dt_limit_kwargs,
)
else:
gate, hidden_states_B_C, time_step = torch.split(
projected_states,
[self.intermediate_size, self.conv_dim, self.num_heads],
dim=-1,
)
time_step = nn.functional.softplus(time_step + self.dt_bias)
# 1D Convolution
if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:
hidden_states_B_C = self.act(
self.conv1d(hidden_states_B_C.transpose(1, 2)).transpose(1, 2)[:, :seq_len]
) # (B, L, self.d_inner + 2 * ngroups * d_state)
else:
hidden_states_B_C = causal_conv1d_fn(
x=hidden_states_B_C.transpose(1, 2),
weight=self.conv1d.weight.squeeze(1),
bias=self.conv1d.bias,
activation=self.activation,
).transpose(1, 2)[:, :seq_len]
hidden_states, B, C = torch.split(
hidden_states_B_C,
[self.intermediate_size, groups_time_state_size, groups_time_state_size],
dim=-1,
)
if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
# tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
dtype = hidden_states.dtype
hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
scan_output, ssm_state = mamba_chunk_scan_combined(
hidden_states.view(batch_size, seq_len, -1, self.head_dim),
time_step,
A,
B.view(batch_size, seq_len, self.n_groups, -1),
C.view(batch_size, seq_len, self.n_groups, -1),
chunk_size=self.chunk_size,
D=self.D,
z=None,
seq_idx=None,
return_final_states=True,
**dt_limit_kwargs,
)
if ssm_state is not None and cache_params is not None:
cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
scan_output = scan_output.view(batch_size, seq_len, -1)
# Multiply "gate" branch and apply extra normalization layer
scan_output = self.norm(scan_output, gate)
out = self.out_proj(scan_output)
return out
# fmt: off
def torch_forward(self, input_states, cache_params: Optional[Mamba2Cache]=None, cache_position:Optional[torch.LongTensor]=None, attention_mask: Optional[torch.Tensor]=None):
batch_size, seq_len, _ = input_states.shape
dtype = input_states.dtype
# Gated MLP's linear projection
projected_states = self.in_proj(input_states.squeeze(1))
d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size - 2 * self.n_groups * self.ssm_state_size- self.num_heads) // 2
_, _, gate, hidden_states, dt = projected_states.split(
[d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
)
# Convolution sequence transformation
if cache_params is not None:
ssm_state = cache_params.ssm_states[self.layer_idx].clone()
ssm_state = ssm_state.to(hidden_states.device)
if cache_params.seqlen_offset > 0:
conv_state = cache_params.conv_states[self.layer_idx] # [batch, intermediate_size, conv_kernel_size]
conv_state = torch.roll(conv_state, shifts=-1, dims=-1)
# handle batched generation - states are copied through
conv_state[:, :, -1] = hidden_states[:, 0, :] if hidden_states.ndim == 3 else hidden_states
cache_params.conv_states[self.layer_idx].copy_(conv_state)
hidden_states = torch.sum(conv_state.to(projected_states.device) * self.conv1d.weight[:, 0, :], dim=-1)
if self.use_conv_bias:
hidden_states += self.conv1d.bias
hidden_states = self.act(hidden_states).to(dtype)[:, None, ...] # [batch, 1, intermediate_size] : decoding
else:
hidden_states = hidden_states.transpose(1,2)
conv_state = nn.functional.pad(
hidden_states,
(self.conv_kernel_size - hidden_states.shape[-1], 0)
)
cache_params.conv_states[self.layer_idx].copy_(conv_state)
hidden_states = self.act(self.conv1d(hidden_states).transpose(1,2))[:, :seq_len, :] # [batch, intermediate_size, seq_len]
if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
dtype = hidden_states.dtype
# tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
else:
ssm_state = torch.zeros(
(batch_size, self.num_heads, self.head_dim, self.ssm_state_size),
device=hidden_states.device, dtype=dtype
)
hidden_states = self.act(self.conv1d(hidden_states.transpose(1, 2))[..., :seq_len].transpose(1, 2))
hidden_states, B, C = torch.split(hidden_states, [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size], dim=-1)
A = -torch.exp(self.A_log.float()) # [num_heads]
if cache_params is not None and cache_params.seqlen_offset > 0:
# Note: there is no need to pad parameter matrices here, as there is just one new token
# for batched generation
dt = dt[:, None, ...] if dt.ndim == 2 else dt[:, 0, :][:, None, ...]
dt = dt.transpose(1, 2).expand(batch_size, dt.shape[-1], self.head_dim)
# [num_heads] -> [num_heads, head_dim]
dt_bias = self.dt_bias[..., None].expand(self.dt_bias.shape[0], self.head_dim)
dt = torch.nn.functional.softplus(dt + dt_bias.to(dt.dtype))
dt = torch.clamp(dt, self.time_step_min) #, self.time_step_max)
A = A[..., None, None].expand(self.num_heads, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
# [bsz, num_heads, head_dim, state_size]
dA = torch.exp(dt[..., None] * A)
# Discretize B
# [bsz, n_groups * state_size] -> [bsz, n_groups, 1, state_size] ->
# -> [bsz, n_groups, group to head repetition factor, state_size] -> [bsz, num_heads, state_size]
B = B.reshape(batch_size, self.n_groups, -1)[..., None, :]
B = B.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, B.shape[-1]).contiguous()
B = B.reshape(batch_size, -1, B.shape[-1])
# [bsz, num_heads, head_dim, state_size]
dB = dt[..., None] * B[..., None, :]
# Discretize x into dB
# [bsz, intermediate_size] -> [bsz, num_heads, head_dim]
hidden_states = hidden_states.reshape(batch_size, -1, self.head_dim)
dBx = dB * hidden_states[..., None]
# State calculation
cache_params.ssm_states[self.layer_idx].copy_(
cache_params.ssm_states[self.layer_idx] * dA + dBx
)
# Subsequent output
# [bsz, n_groups * state_size] -> [bsz, num_heads, state_size]
C = C.reshape(batch_size, self.n_groups, -1)[..., None, :]
C = C.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, C.shape[-1]).contiguous()
C = C.reshape(batch_size, -1, C.shape[-1])
# [bsz, num_heads, head_dim]
ssm_states = cache_params.ssm_states[self.layer_idx].to(C.dtype) # Shape: [b, h, d, n]
# Reshape ssm_states to merge the first two dimensions
ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size) # Shape: [b*h, d, n]
C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1) # Shape: [b*h, n, 1]
y = torch.bmm(ssm_states_reshaped, C_reshaped)
y = y.view(batch_size, self.num_heads, self.head_dim)
# D skip connection
# [num_heads] -> [num_heads, head_dim]
D = self.D[..., None].expand(self.D.shape[0], self.head_dim)
y = (y + hidden_states * D).to(y.dtype)
# [bsz, num_heads, head_dim] -> [bsz, 1, intermediate_size]
y = y.reshape(batch_size, -1)[:, None, ...]
else:
# begin ssd naive implementation without einsums
dt = nn.functional.softplus(dt + self.dt_bias)
dt = torch.clamp(dt, self.time_step_min)
hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float()
B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
B = B.repeat(1, 1, self.num_heads // self.n_groups, 1)
C = C.repeat(1, 1, self.num_heads // self.n_groups, 1)
pad_size = self.chunk_size - (seq_len % self.chunk_size)
D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size)
# Discretize x and A
hidden_states = hidden_states * dt[..., None]
A = A.to(hidden_states.dtype) * dt
# Rearrange into blocks/chunks
hidden_states, A, B, C = [reshape_into_chunks(t, pad_size, self.chunk_size) for t in (hidden_states, A, B, C)]
# [bsz, -1, chunk_size, num_heads] -> [bsz, num_heads, -1, chunk_size]
A = A.permute(0, 3, 1, 2)
A_cumsum = torch.cumsum(A, dim=-1)
# 1. Compute the output for each intra-chunk (diagonal blocks)
# This is the analog of a causal mask
L = torch.exp(segment_sum(A))
# First, contraction of C and B to get G (attention-weights like)
G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, : ,:] # shape: (b, c, l, s, h, n)
G = G_intermediate.sum(dim=-1) # shape: (b, c, l, s, h)
# Step 2: Compute M, equivalent to applying attention mask to weights
M_intermediate = G[..., None] * L.permute(0, 2, 3, 4, 1)[..., None]
M = M_intermediate.sum(dim=-1)
# Step 3: Compute Y_diag (apply to values)
Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(3)
# (right term of low-rank factorization of off-diagonal blocks; B terms)
decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum))
B_decay_contraction = B * decay_states.permute(0, 2, 3, 1)[..., None]
# permute back B * decay states
states = (B_decay_contraction.permute(0, 1, 3, 2, 4)[..., None] * hidden_states.permute(0, 1, 3, 2, 4)[..., None, :]).sum(dim=3).permute(0, 1, 2, 4, 3)
if cache_params is not None and cache_params.seqlen_offset > 0:
previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...]
else:
previous_states = torch.zeros_like(states[:, :1])
states = torch.cat([previous_states, states], dim=1)
decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0))))
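# decay_chunk[..., i, j] is the decay factor applied to the state at the end of chunk j when it is
# carried to the end of chunk i (the exponential of the summed per-chunk A totals in between);
# index 0 corresponds to the initial state before the first chunk (the padded zero above).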
states_permuted = states.permute(0, 2, 1, 3, 4)
result = (decay_chunk[..., None, None] * states_permuted[:, :, None, ...]).sum(dim=2)
new_states = result.permute(0, 2, 1, 3, 4)
states, ssm_state = new_states[:, :-1], new_states[:, -1]
# Compute state -> output conversion per chunk
# (left term of low-rank factorization of off-diagonal blocks; C terms)
state_decay_out = torch.exp(A_cumsum)
# compute Yoff
C_times_states = (C[..., None, :] * states[:, :, None, ...])
state_decay_out_permuted = state_decay_out.permute(0, 2, 3, 1)
Y_off = (C_times_states.sum(-1) * state_decay_out_permuted[..., None])
# Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
y = Y_diag + Y_off
# [bsz, -1, self.chunk_size, num_heads, head_dim] -> [bsz, (padded) seq_len, num_heads, head_dim]
y = y.reshape(batch_size, -1, self.num_heads, self.head_dim)
y = y + D_residual
# Cutting off padded chunks
if pad_size > 0:
y = y[:, :seq_len, :, :]
y = y.reshape(batch_size, seq_len, -1)
if ssm_state is not None and cache_params is not None:
cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
scan_output = self.norm(y, gate)
# end ssd naive
# 4. Final linear projection
contextualized_states = self.out_proj(scan_output.to(dtype)) # [batch, seq_len, hidden_size]
return contextualized_states
# fmt: on
def forward(
self,
hidden_states,
cache_params: Optional[Mamba2Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
):
if is_fast_path_available and "cuda" in self.in_proj.weight.device.type:
return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
dtype = hidden_states.dtype
if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
# tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
return self.torch_forward(hidden_states, cache_params, cache_position, attention_mask)
class Mamba2RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
Mamba2RMSNorm is equivalent to T5LayerNorm and LlamaRMSNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
class Mamba2Block(nn.Module):
def __init__(self, config, layer_idx):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.residual_in_fp32 = config.residual_in_fp32
self.norm = Mamba2RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
self.mixer = Mamba2Mixer(config, layer_idx=layer_idx)
def forward(
self,
hidden_states,
cache_params: Optional[Mamba2Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
):
residual = hidden_states
hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype))
if self.residual_in_fp32:
residual = residual.to(torch.float32)
hidden_states = self.mixer(
hidden_states, cache_params=cache_params, cache_position=cache_position, attention_mask=attention_mask
)
hidden_states = residual + hidden_states
return hidden_states
class Mamba2PreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = Mamba2Config
base_model_prefix = "backbone"
_no_split_modules = ["Mamba2Block"]
supports_gradient_checkpointing = True
_is_stateful = True
def _init_weights(self, module):
"""Initialize the weights."""
if isinstance(module, Mamba2Mixer):
module.A_log._no_weight_decay = True
module.D._no_weight_decay = True
dt = torch.exp(
torch.rand(self.config.num_heads)
* (math.log(self.config.time_step_max) - math.log(self.config.time_step_min))
+ math.log(self.config.time_step_min)
).clamp(min=self.config.time_step_floor)
# # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
inv_dt = dt + torch.log(-torch.expm1(-dt))
with torch.no_grad():
module.dt_bias.copy_(inv_dt)
module.dt_bias._no_reinit = True
if isinstance(module, nn.Linear):
if module.bias is not None:
if not getattr(module.bias, "_no_reinit", False):
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
nn.init.normal_(module.weight, std=self.config.initializer_range)
if self.config.rescale_prenorm_residual:
# Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
# > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
# > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
# > -- GPT-2 :: https://openai.com/blog/better-language-models/
#
# Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
for name, p in module.named_parameters():
if name in ["out_proj.weight"]:
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
# Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
# We need to reinit p since this code could be called multiple times
# Having just p *= scale would repeatedly scale it down
nn.init.kaiming_uniform_(p, a=math.sqrt(5))
with torch.no_grad():
p /= math.sqrt(self.config.num_hidden_layers)
@dataclass
# Copied from transformers.models.mamba.modeling_mamba.MambaOutput with MAMBA->MAMBA2,Mamba->Mamba2
class Mamba2Output(ModelOutput):
"""
Class for the MAMBA2 model outputs.
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
cache_params (`Mamba2Cache`):
The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
avoid providing the old `input_ids`.
Includes both the State space model state matrices after the selective scan, and the Convolutional states
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
"""
last_hidden_state: Optional[torch.FloatTensor] = None
cache_params: Optional[Mamba2Cache] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
# Copied from transformers.models.mamba.modeling_mamba.MambaCausalLMOutput with Mamba->Mamba2
class Mamba2CausalLMOutput(ModelOutput):
"""
Base class for causal language model (or autoregressive) outputs.
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Language modeling loss (for next-token prediction).
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
cache_params (`Mamba2Cache`):
The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
avoid providing the old `input_ids`.
Includes both the State space model state matrices after the selective scan, and the Convolutional states
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
"""
loss: Optional[torch.FloatTensor] = None
logits: Optional[torch.FloatTensor] = None
cache_params: Optional[Mamba2Cache] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
MAMBA2_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`Mamba2Config`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
MAMBA2_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
Indices of input sequence tokens in the vocabulary.
If `cache_params.seqlen_offset>0`, only `input_ids` that do not have their past calculated should be passed as
`input_ids`.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
cache_params (`Mamba2Cache`, *optional*):
If passed along, the model uses the previous state in all the blocks (which will give the output for the
`input_ids` provided as if the model adds `state_input_ids + input_ids` as context).
use_cache (`bool`, *optional*):
If set to `True`, the `cache_params` is returned and can be used to quickly generate the next logits.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@add_start_docstrings(
"The bare MAMBA2 Model transformer outputting raw hidden-states without any specific head on top.",
MAMBA2_START_DOCSTRING,
)
class Mamba2Model(Mamba2PreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
self.layers = nn.ModuleList([Mamba2Block(config, layer_idx=idx) for idx in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
self.norm_f = Mamba2RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
# Initialize weights and apply final processing
self._register_load_state_dict_pre_hook(self.load_hook)
self.post_init()
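# The hook below remaps state dict keys saved with the original `embedding.` naming to the
# `embeddings.` attribute used by this class, so such checkpoints still load correctly.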
def load_hook(self, state_dict, prefix, *args):
for k in state_dict:
if "embedding." in k:
state_dict[k.replace("embedding.", "embeddings.")] = state_dict.pop(k)
break
def get_input_embeddings(self):
return self.embeddings
def set_input_embeddings(self, new_embeddings):
self.embeddings = new_embeddings
@add_start_docstrings_to_model_forward(MAMBA2_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=Mamba2Output,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.LongTensor] = None,
cache_params: Optional[Mamba2Cache] = None,
use_cache: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[Tuple, Mamba2Output]:
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if (input_ids is None) ^ (inputs_embeds is not None): # ^ is python for xor
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
)
if inputs_embeds is None:
inputs_embeds = self.embeddings(input_ids)
if self.gradient_checkpointing and self.training and use_cache:
use_cache = False
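# Prefill: when no cache is passed, create an empty `Mamba2Cache` and a `cache_position`
# spanning the convolution kernel window; otherwise a valid `cache_position` must be provided.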
if use_cache:
if cache_params is None:
cache_params = Mamba2Cache(
self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype
)
cache_position = torch.arange(0, self.config.conv_kernel, device=inputs_embeds.device)
elif cache_position is None:
# cases when we do a manual forward instead of using `model.generate`, which would have
# initialized `cache_position` and made sure it is not None; throw an error here instead
# of trying to guess the current cache position
raise ValueError(
"You have to specify the `cache_position` manually when `use_cache=True` and `cache_params` is passed, "
"you don't have to pass a `cache_params` if you are in prefilling stage because in that case it will "
"be initialized for you automatically"
)
else:
cache_params = None
hidden_states = inputs_embeds
all_hidden_states = () if output_hidden_states else None
for mixer_block in self.layers:
if self.gradient_checkpointing and self.training:
hidden_states = self._gradient_checkpointing_func(
mixer_block.__call__, hidden_states, cache_params, cache_position, attention_mask
)
else:
hidden_states = mixer_block(
hidden_states,
cache_params=cache_params,
cache_position=cache_position,
attention_mask=attention_mask,
)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if use_cache:
cache_params.seqlen_offset += inputs_embeds.shape[1]
hidden_states = self.norm_f(hidden_states)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, cache_params, all_hidden_states] if v is not None)
return Mamba2Output(
last_hidden_state=hidden_states,
cache_params=cache_params if use_cache else None,
hidden_states=all_hidden_states,
)
@add_start_docstrings(
"""
The MAMBA2 Model transformer with a language modeling head on top (linear layer with weights not tied to the input
embeddings).
""",
MAMBA2_START_DOCSTRING,
)
class Mamba2ForCausalLM(Mamba2PreTrainedModel):
_tied_weights_keys = []
def __init__(self, config):
super().__init__(config)
self.backbone = Mamba2Model(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def get_input_embeddings(self):
return self.backbone.get_input_embeddings()
def set_input_embeddings(self, new_embeddings):
return self.backbone.set_input_embeddings(new_embeddings)
def prepare_inputs_for_generation(
self,
input_ids,
inputs_embeds=None,
use_cache=None,
cache_params: Optional[Mamba2Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
**kwargs,
):
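# `past_len` is the total sequence length seen so far; when generating from `inputs_embeds`
# only, `input_ids` can be empty, so fall back to the embeddings' length.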
if input_ids.shape[1] == 0:
past_len = inputs_embeds.shape[1]
else:
past_len = input_ids.shape[1]
if use_cache:
# `cache_position` should have been initialized in `generate`
if cache_position is None:
raise ValueError(
"`cache_position` should not be None as it should have been initialized in "
"`model.generate`, you are responsible for passing in a valid `cache_position` if "
"you are calling `prepare_inputs_for_generation` directly with `use_cache=True`"
)
# how do we detect that we are in decoding without cache?
if cache_position[0] > 0:
input_ids = input_ids[:, -1][..., None]
attention_mask = attention_mask[:, -1][..., None]
else:
# we initialize the `cache_position` to full size of `conv_states` at prefill stage
# considering padding will be applied when input length is shorter, and truncation
# will be applied when it is longer, so it will be equivalent to always have it match
# the length of `cache_params.conv_states`, which is `config.conv_kernel`
cache_position = torch.arange(0, past_len, device=input_ids.device)
# if the cache is not used, we also do have to extend the attention mask here
# TODO there is likely a cleverer way to do this
extended_mask = torch.ones(
attention_mask.size(0), past_len - attention_mask.shape[1], device=attention_mask.device
)
attention_mask = torch.cat([attention_mask, extended_mask], dim=1)
cache_params = None
if attention_mask.shape[1] < past_len:
# we have to update manually the attention mask if
# we are in decoding without cache
# and we don't have position_ids here
# TODO but we should be able to use cache_position though at a later time
extended_mask = torch.ones(
attention_mask.size(0), past_len - attention_mask.shape[1], device=attention_mask.device
)
attention_mask = torch.cat([attention_mask, extended_mask], dim=1)
if inputs_embeds is not None and cache_params is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"attention_mask": attention_mask,
"cache_params": cache_params,
"use_cache": use_cache,
"cache_position": cache_position,
}
)
return model_inputs
@add_start_docstrings_to_model_forward(MAMBA2_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=Mamba2CausalLMOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
cache_params: Optional[Mamba2Cache] = None,
labels: Optional[torch.LongTensor] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
use_cache: Optional[bool] = None,
cache_position: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
**kwargs, # for now we need this for generation
) -> Union[Tuple, Mamba2CausalLMOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
`labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
mamba2_outputs = self.backbone(
input_ids,
cache_params=cache_params,
inputs_embeds=inputs_embeds,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
use_cache=use_cache,
cache_position=cache_position,
attention_mask=attention_mask,
)
hidden_states = mamba2_outputs[0]
logits = self.lm_head(hidden_states.to(self.lm_head.weight.dtype)).float()
loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(logits.device)
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
if not return_dict:
output = (logits,) + mamba2_outputs[1:]
return ((loss,) + output) if loss is not None else output
return Mamba2CausalLMOutput(
loss=loss,
logits=logits,
cache_params=mamba2_outputs.cache_params,
hidden_states=mamba2_outputs.hidden_states,
)
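# A minimal training-style sketch of the causal LM head above, using small config values borrowed
# from the test suite; labels are shifted internally, so `labels=input_ids` is valid.
# Illustrative only, not part of the module:
#
#     config = Mamba2Config(
#         num_heads=8, n_groups=8, state_size=2, head_dim=8, conv_kernel=4, chunk_size=8,
#         vocab_size=99, hidden_size=32, num_hidden_layers=2,
#     )
#     model = Mamba2ForCausalLM(config)
#     input_ids = torch.randint(0, config.vocab_size, (2, 7))
#     loss = model(input_ids=input_ids, labels=input_ids).loss
#     loss.backward()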
...@@ -5542,6 +5542,27 @@ class MambaPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["torch"])
class Mamba2ForCausalLM(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class Mamba2Model(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class Mamba2PreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class MarianForCausalLM(metaclass=DummyObject):
_backends = ["torch"]
...
...@@ -385,6 +385,21 @@ def is_mamba_ssm_available():
return False
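# The mamba2 fused path needs a CUDA-enabled torch install and the `mamba_ssm` package at
# version 2.0.4 or newer; anything else falls back to returning False.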
def is_mamba_2_ssm_available():
if is_torch_available():
import torch
if not torch.cuda.is_available():
return False
else:
if _is_package_available("mamba_ssm"):
import mamba_ssm
if version.parse(mamba_ssm.__version__) >= version.parse("2.0.4"):
return True
return False
def is_causal_conv1d_available():
if is_torch_available():
import torch
...
# coding=utf-8
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from typing import Dict, List, Tuple
from parameterized import parameterized
from transformers import AutoTokenizer, Mamba2Config, is_torch_available
from transformers.testing_utils import require_torch, require_torch_gpu, slow, torch_device
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available():
import torch
from transformers import (
Mamba2ForCausalLM,
Mamba2Model,
)
from transformers.models.mamba2.modeling_mamba2 import Mamba2Cache
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_0
else:
is_torch_greater_or_equal_than_2_0 = False
class Mamba2ModelTester:
def __init__(
self,
parent,
batch_size=14,
num_heads=8,
n_groups=8,
state_size=2,
head_dim=8,
conv_kernel=4,
chunk_size=8,
seq_length=7,
is_training=True,
use_labels=True,
vocab_size=99,
hidden_size=32,
num_hidden_layers=2,
hidden_act="silu",
hidden_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=16,
type_sequence_label_size=2,
num_labels=3,
num_choices=4,
scope=None,
tie_word_embeddings=False,
):
self.parent = parent
self.num_heads = num_heads
self.n_groups = n_groups
self.head_dim = head_dim
self.state_size = state_size
self.conv_kernel = conv_kernel
self.chunk_size = chunk_size
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_labels = use_labels
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.type_sequence_label_size = type_sequence_label_size
self.num_labels = num_labels
self.num_choices = num_choices
self.scope = scope
self.bos_token_id = vocab_size - 1
self.eos_token_id = vocab_size - 1
self.pad_token_id = vocab_size - 1
self.tie_word_embeddings = tie_word_embeddings
def get_large_model_config(self):
return Mamba2Config.from_pretrained("mistralai/Mamba-Codestral-7B-v0.1", revision="refs/pr/9")
def prepare_config_and_inputs(
self, gradient_checkpointing=False, scale_attn_by_inverse_layer_idx=False, reorder_and_upcast_attn=False
):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
sequence_labels = None
token_labels = None
choice_labels = None
if self.use_labels:
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
choice_labels = ids_tensor([self.batch_size], self.num_choices)
config = self.get_config(
gradient_checkpointing=gradient_checkpointing,
)
return (
config,
input_ids,
None,
sequence_labels,
token_labels,
choice_labels,
)
def get_config(self, gradient_checkpointing=False):
return Mamba2Config(
head_dim=self.head_dim,
num_heads=self.num_heads,
n_groups=self.n_groups,
state_size=self.state_size,
conv_kernel=self.conv_kernel,
chunk_size=self.chunk_size,
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
activation_function=self.hidden_act,
n_positions=self.max_position_embeddings,
type_vocab_size=self.type_vocab_size,
use_cache=True,
bos_token_id=self.bos_token_id,
eos_token_id=self.eos_token_id,
pad_token_id=self.pad_token_id,
gradient_checkpointing=gradient_checkpointing,
tie_word_embeddings=self.tie_word_embeddings,
)
def prepare_config_and_inputs_for_common(self):
(
config,
input_ids,
_,
sequence_labels,
token_labels,
choice_labels,
) = self.prepare_config_and_inputs()
inputs_dict = {"input_ids": input_ids}
return config, inputs_dict
@unittest.skipIf(
not is_torch_greater_or_equal_than_2_0, reason="See https://github.com/huggingface/transformers/pull/24204"
)
@require_torch
class Mamba2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (Mamba2Model, Mamba2ForCausalLM) if is_torch_available() else ()
all_generative_model_classes = (Mamba2ForCausalLM,) if is_torch_available() else ()
has_attentions = False # Mamba does not support attentions
fx_compatible = False # FIXME let's try to support this @molbap
test_torchscript = False # FIXME I think this should be doable @molbap @ArthurZucker
test_missing_keys = False
test_model_parallel = False
test_pruning = False
test_head_masking = False # Mamba does not have attention heads
pipeline_model_mapping = (
{"feature-extraction": Mamba2Model, "text-generation": Mamba2ForCausalLM} if is_torch_available() else {}
)
def setUp(self):
self.model_tester = Mamba2ModelTester(self)
self.config_tester = ConfigTester(
self, config_class=Mamba2Config, n_embd=37, common_properties=["hidden_size", "num_hidden_layers"]
)
def test_initialization(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config=config)
for name, param in model.named_parameters():
if "D" in name:
if param.requires_grad:
# the D parameter is expected to be initialized to all ones
self.assertTrue(torch.allclose(param.data, torch.ones_like(param.data), atol=1e-5, rtol=1e-5))
@unittest.skip(reason="Mamba 2 weights are not tied")
def test_tied_weights_keys(self):
pass
@unittest.skip(reason="To fix, Mamba 2 cache slicing is interacting with beam search")
def test_beam_search_generate_dict_outputs_use_cache(self):
pass
@unittest.skip(reason="To fix, Mamba 2 cache slicing is interacting with beam search")
def test_beam_sample_generate(self):
pass
@unittest.skip(reason="To fix, Mamba 2 cache slicing test case is an edge case")
def test_generate_without_input_ids(self):
pass
@unittest.skip(reason="To fix, Mamba 2 cache slicing test case is an edge case")
def test_greedy_generate_dict_outputs_use_cache(self):
pass
@unittest.skip(reason="Initialization of mamba2 fails this")
def test_save_load_fast_init_from_base(self):
pass
@unittest.skip(reason="A large mamba2 would be necessary (and costly) for that")
def test_multi_gpu_data_parallel_forward(self):
pass
@unittest.skip(reason="To fix, Mamba 2 cache slicing test case is an edge case")
def test_generate_from_inputs_embeds_decoder_only(self):
pass
def test_model_outputs_equivalence(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
with torch.no_grad():
tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple()
def recursive_check(tuple_object, dict_object):
if isinstance(tuple_object, Mamba2Cache): # MODIFIED PART START
recursive_check(tuple_object.conv_states, dict_object.conv_states)
recursive_check(tuple_object.ssm_states, dict_object.ssm_states)
elif isinstance(tuple_object, (List, Tuple)): # MODIFIED PART END
for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
recursive_check(tuple_iterable_value, dict_iterable_value)
elif isinstance(tuple_object, Dict):
for tuple_iterable_value, dict_iterable_value in zip(
tuple_object.values(), dict_object.values()
):
recursive_check(tuple_iterable_value, dict_iterable_value)
elif tuple_object is None:
return
else:
self.assertTrue(
torch.allclose(tuple_object, dict_object, atol=1e-5),
msg=(
"Tuple and dict output are not equal. Difference:"
f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:"
f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has"
f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}."
),
)
recursive_check(tuple_output, dict_output)
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
check_equivalence(model, tuple_inputs, dict_inputs)
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
check_equivalence(model, tuple_inputs, dict_inputs)
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
@require_torch
@slow
class Mamba2IntegrationTest(unittest.TestCase):
def setUp(self):
self.model_id = "mistralai/Mamba-Codestral-7B-v0.1"
self.tokenizer = AutoTokenizer.from_pretrained(
self.model_id, revision="refs/pr/9", from_slow=True, legacy=False
)
self.prompt = "[INST]Write a hello world program in C++."
@parameterized.expand(
[
(torch_device,),
]
)
@slow
@require_torch
def test_simple_generate(self, device):
"""
Simple generate test to avoid regressions.
Note: state-spaces (cuda) implementation and pure torch implementation
have irreconcilable differences as of now, which will cause this test to fail
in an environment with state-spaces installed.
"""
tokenizer = self.tokenizer
tokenizer.pad_token_id = tokenizer.eos_token_id
model = Mamba2ForCausalLM.from_pretrained(self.model_id, revision="refs/pr/9", torch_dtype=torch.bfloat16)
model.to(device)
input_ids = tokenizer("[INST]Write a hello world program in C++.[/INST]", return_tensors="pt")["input_ids"].to(
device
)
out = model.generate(input_ids, do_sample=False, use_cache=True, max_new_tokens=30)
output_sentence = tokenizer.decode(out[0])
ground_truth_sentence = """<s>[INST]Write a hello world program in C++.[/INST] Sure, here is a simple "Hello, World!" program in C++:\n\n```cpp\n#include <iostream>\n\n"""
self.assertEqual(output_sentence, ground_truth_sentence)
@slow
@require_torch_gpu
def test_batched_equivalence_with_cache(self):
"""
Verifies that batched generation matches individual generation.
Important because of the specific caching mechanism + statefulness of mamba model.
Depending on precision and devices, differences can be observed from generation to generation.
"""
tokenizer = self.tokenizer
prompt = [
"[INST]Write C#.[/INST]",
"[INST]Write a hello world in C++.[/INST]",
"[INST] Write a simple Fibonacci number computation function in Rust that does memoization, with comments, in safe Rust.[/INST]",
]
model = Mamba2ForCausalLM.from_pretrained(self.model_id, revision="refs/pr/9", torch_dtype=torch.bfloat16).to(
torch_device
)
tokenizer.pad_token_id = tokenizer.eos_token_id
# batched generation
tokenized_prompts = tokenizer(prompt, return_tensors="pt", padding="longest").to(torch_device)
batched_gen = model.generate(**tokenized_prompts, max_new_tokens=30, use_cache=True)
batched_output = tokenizer.batch_decode(batched_gen, skip_special_tokens=True)
# individual generation
for index_gen, individual_prompt in enumerate(prompt):
inputs = tokenizer(individual_prompt, return_tensors="pt", padding="longest").to(torch_device)
individual_gen = model.generate(**inputs, max_new_tokens=30, use_cache=True)
individual_output = tokenizer.batch_decode(individual_gen, skip_special_tokens=True)[0]
self.assertEqual(individual_output[:100], batched_output[index_gen][:100])
@slow
@require_torch_gpu
def test_batched_equivalence_without_cache(self):
"""
Verifies that batched generation matches individual generation without cache.
Important because of the specific caching mechanism + statefulness of mamba model.
Depending on precision and devices, differences can be observed from generation to generation.
"""
tokenizer = self.tokenizer
prompt = [
"[INST]Write C#.[/INST]",
"[INST]Write a hello world in C++.[/INST]",
"[INST] Write a simple Fibonacci number computation function in Rust that does memoization, with comments, in safe Rust.[/INST]",
]
model = Mamba2ForCausalLM.from_pretrained(self.model_id, revision="refs/pr/9", torch_dtype=torch.bfloat16).to(
torch_device
)
tokenizer.pad_token_id = tokenizer.eos_token_id
# batched generation
tokenized_prompts = tokenizer(prompt, return_tensors="pt", padding="longest").to(torch_device)
batched_gen = model.generate(**tokenized_prompts, max_new_tokens=30, use_cache=False)
batched_output = tokenizer.batch_decode(batched_gen, skip_special_tokens=True)
# individual generation
for index_gen, individual_prompt in enumerate(prompt):
inputs = tokenizer(individual_prompt, return_tensors="pt", padding="longest").to(torch_device)
individual_gen = model.generate(**inputs, max_new_tokens=30, use_cache=False)
individual_output = tokenizer.batch_decode(individual_gen, skip_special_tokens=True)[0]
self.assertEqual(individual_output[:100], batched_output[index_gen][:100])