Unverified Commit 7c114912 authored by Younes Belkada, committed by GitHub

Add new model (#32615)



* v1 - working version

* fix

* fix

* fix

* fix

* rename to correct name

* fix title

* fixup

* rename files

* fix

* add copied from on tests

* rename to `FalconMamba` everywhere and fix bugs

* fix quantization + accelerate

* fix copies

* add `torch.compile` support

* fix tests

* fix tests and add slow tests

* copies on config

* merge the latest changes

* fix tests

* add few lines about instruct

* Apply suggestions from code review
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fix

* fix tests

---------
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
parent 48101cf8
@@ -370,6 +370,8 @@
       title: ESM
     - local: model_doc/falcon
       title: Falcon
+    - local: model_doc/falcon_mamba
+      title: FalconMamba
     - local: model_doc/fastspeech2_conformer
       title: FastSpeech2Conformer
     - local: model_doc/flan-t5
@@ -136,6 +136,7 @@ Flax), PyTorch, and/or TensorFlow.
 | [ESM](model_doc/esm) | ✅ | ✅ | ❌ |
 | [FairSeq Machine-Translation](model_doc/fsmt) | ✅ | ❌ | ❌ |
 | [Falcon](model_doc/falcon) | ✅ | ❌ | ❌ |
+| [FalconMamba](model_doc/falcon_mamba) | ✅ | ❌ | ❌ |
 | [FastSpeech2Conformer](model_doc/fastspeech2_conformer) | ✅ | ❌ | ❌ |
 | [FLAN-T5](model_doc/flan-t5) | ✅ | ✅ | ✅ |
 | [FLAN-UL2](model_doc/flan-ul2) | ✅ | ✅ | ✅ |
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# FalconMamba
## Overview
The FalconMamba model was proposed by TII UAE (Technology Innovation Institute) in their FalconMamba release.
The abstract from the paper is the following:
*We present FalconMamba, a new base large language model based on the novel Mamba architecture. FalconMamba is trained on 5.8 trillion tokens with carefully selected data mixtures. As a pure Mamba-based model, FalconMamba surpasses leading open-weight models based on Transformers, such as Mistral 7B, Llama3 8B, and Falcon2 11B. It is on par with Gemma 7B and outperforms models with different architecture designs, such as RecurrentGemma 9B. Currently, FalconMamba is the best-performing Mamba model in the literature at this scale, surpassing both existing Mamba and hybrid Mamba-Transformer models.
Due to its architecture, FalconMamba is significantly faster at inference and requires substantially less memory for long sequence generation. Despite recent studies suggesting that hybrid Mamba-Transformer models outperform pure architecture designs, we argue and demonstrate that the pure Mamba design can achieve similar, even superior results compared to the hybrid design. We make the weights of our implementation of FalconMamba publicly available under a permissive license.*
Tips:

- FalconMamba is mostly based on the Mamba architecture, so the same [tips and best practices](./mamba) are relevant here.

The model was trained on approximately 6T tokens consisting of a mixture of many data sources, such as RefinedWeb, Cosmopedia, and math data.

For more details about the training procedure and the architecture, have a look at [the FalconMamba technical paper]() (coming soon).

## Usage
Below we demonstrate how to use the model:
```python
from transformers import FalconMambaForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-mamba-7b")
model = FalconMambaForCausalLM.from_pretrained("tiiuae/falcon-mamba-7b")

input_ids = tokenizer("Hey how are you doing?", return_tensors="pt")["input_ids"]
out = model.generate(input_ids, max_new_tokens=10)
print(tokenizer.batch_decode(out))
```
The architecture is also compatible with `torch.compile` for faster generation:
```python
import torch
from transformers import FalconMambaForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-mamba-7b")
model = FalconMambaForCausalLM.from_pretrained("tiiuae/falcon-mamba-7b", torch_dtype=torch.bfloat16).to(0)
model = torch.compile(model)

# Move the inputs to the same device as the model.
input_ids = tokenizer("Hey how are you doing?", return_tensors="pt")["input_ids"].to(0)
out = model.generate(input_ids, max_new_tokens=10)
print(tokenizer.batch_decode(out))
```
If you have access to a GPU that is compatible with `bitsandbytes`, you can also quantize the model in 4-bit precision:
```python
from transformers import FalconMambaForCausalLM, AutoTokenizer, BitsAndBytesConfig

tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-mamba-7b")
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
model = FalconMambaForCausalLM.from_pretrained("tiiuae/falcon-mamba-7b", quantization_config=quantization_config)

input_ids = tokenizer("Hey how are you doing?", return_tensors="pt")["input_ids"]
out = model.generate(input_ids, max_new_tokens=10)
print(tokenizer.batch_decode(out))
```
You can also play with the instruction fine-tuned model:
```python
from transformers import FalconMambaForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-mamba-7b-instruct")
model = FalconMambaForCausalLM.from_pretrained("tiiuae/falcon-mamba-7b-instruct")

# We use the tokenizer's chat template to format each message - see https://huggingface.co/docs/transformers/main/en/chat_templating
messages = [
    {"role": "user", "content": "How many helicopters can a human eat in one sitting?"},
]
input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")

outputs = model.generate(input_ids, max_new_tokens=30)
print(tokenizer.decode(outputs[0]))
```
## FalconMambaConfig
[[autodoc]] FalconMambaConfig
## FalconMambaModel
[[autodoc]] FalconMambaModel
- forward
## FalconMambaForCausalLM
[[autodoc]] FalconMambaForCausalLM
- forward
@@ -416,6 +416,7 @@ _import_structure = {
     "models.ernie": ["ErnieConfig"],
     "models.esm": ["EsmConfig", "EsmTokenizer"],
     "models.falcon": ["FalconConfig"],
+    "models.falcon_mamba": ["FalconMambaConfig"],
     "models.fastspeech2_conformer": [
         "FastSpeech2ConformerConfig",
         "FastSpeech2ConformerHifiGanConfig",
@@ -2138,6 +2139,13 @@ else:
             "FalconPreTrainedModel",
         ]
     )
+    _import_structure["models.falcon_mamba"].extend(
+        [
+            "FalconMambaForCausalLM",
+            "FalconMambaModel",
+            "FalconMambaPreTrainedModel",
+        ]
+    )
     _import_structure["models.fastspeech2_conformer"].extend(
         [
             "FastSpeech2ConformerHifiGan",
@@ -5127,6 +5135,7 @@ if TYPE_CHECKING:
     from .models.ernie import ErnieConfig
     from .models.esm import EsmConfig, EsmTokenizer
     from .models.falcon import FalconConfig
+    from .models.falcon_mamba import FalconMambaConfig
     from .models.fastspeech2_conformer import (
         FastSpeech2ConformerConfig,
         FastSpeech2ConformerHifiGanConfig,
@@ -6739,6 +6748,11 @@ if TYPE_CHECKING:
         FalconModel,
         FalconPreTrainedModel,
     )
+    from .models.falcon_mamba import (
+        FalconMambaForCausalLM,
+        FalconMambaModel,
+        FalconMambaPreTrainedModel,
+    )
     from .models.fastspeech2_conformer import (
         FastSpeech2ConformerHifiGan,
         FastSpeech2ConformerModel,
@@ -84,6 +84,7 @@ from . import (
     ernie,
     esm,
     falcon,
+    falcon_mamba,
     fastspeech2_conformer,
     flaubert,
     flava,
@@ -100,6 +100,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
         ("ernie_m", "ErnieMConfig"),
         ("esm", "EsmConfig"),
         ("falcon", "FalconConfig"),
+        ("falcon_mamba", "FalconMambaConfig"),
         ("fastspeech2_conformer", "FastSpeech2ConformerConfig"),
         ("flaubert", "FlaubertConfig"),
         ("flava", "FlavaConfig"),
@@ -384,6 +385,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
         ("ernie_m", "ErnieM"),
         ("esm", "ESM"),
         ("falcon", "Falcon"),
+        ("falcon_mamba", "FalconMamba"),
         ("fastspeech2_conformer", "FastSpeech2Conformer"),
         ("flan-t5", "FLAN-T5"),
         ("flan-ul2", "FLAN-UL2"),
@@ -98,6 +98,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
         ("ernie_m", "ErnieMModel"),
         ("esm", "EsmModel"),
         ("falcon", "FalconModel"),
+        ("falcon_mamba", "FalconMambaModel"),
         ("fastspeech2_conformer", "FastSpeech2ConformerModel"),
         ("flaubert", "FlaubertModel"),
         ("flava", "FlavaModel"),
@@ -291,6 +292,7 @@ MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
         ("distilbert", "DistilBertForMaskedLM"),
         ("electra", "ElectraForPreTraining"),
         ("ernie", "ErnieForPreTraining"),
+        ("falcon_mamba", "FalconMambaForCausalLM"),
         ("flaubert", "FlaubertWithLMHeadModel"),
         ("flava", "FlavaForPreTraining"),
         ("fnet", "FNetForPreTraining"),
@@ -377,6 +379,7 @@ MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
         ("encoder-decoder", "EncoderDecoderModel"),
         ("ernie", "ErnieForMaskedLM"),
         ("esm", "EsmForMaskedLM"),
+        ("falcon_mamba", "FalconMambaForCausalLM"),
         ("flaubert", "FlaubertWithLMHeadModel"),
         ("fnet", "FNetForMaskedLM"),
         ("fsmt", "FSMTForConditionalGeneration"),
@@ -462,6 +465,7 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
         ("electra", "ElectraForCausalLM"),
         ("ernie", "ErnieForCausalLM"),
         ("falcon", "FalconForCausalLM"),
+        ("falcon_mamba", "FalconMambaForCausalLM"),
         ("fuyu", "FuyuForCausalLM"),
         ("gemma", "GemmaForCausalLM"),
         ("gemma2", "Gemma2ForCausalLM"),
@@ -180,6 +180,7 @@ else:
         ("ernie_m", ("ErnieMTokenizer" if is_sentencepiece_available() else None, None)),
         ("esm", ("EsmTokenizer", None)),
         ("falcon", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
+        ("falcon_mamba", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
         (
             "fastspeech2_conformer",
             ("FastSpeech2ConformerTokenizer" if is_g2p_en_available() else None, None),
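With these auto mappings registered, FalconMamba also becomes reachable through the generic `AutoConfig`, `AutoTokenizer`, and `AutoModelForCausalLM` entry points. A minimal sketch of what the registrations above enable:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# The "falcon_mamba" entries added above route these generic calls to
# FalconMambaConfig / FalconMambaForCausalLM and the GPT-NeoX fast tokenizer.
tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-mamba-7b")
model = AutoModelForCausalLM.from_pretrained("tiiuae/falcon-mamba-7b")

print(model.config.model_type)  # "falcon_mamba"
```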
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_torch_available,
)
_import_structure = {
"configuration_falcon_mamba": ["FalconMambaConfig"],
}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_falcon_mamba"] = [
"FalconMambaForCausalLM",
"FalconMambaModel",
"FalconMambaPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_falcon_mamba import FalconMambaConfig
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_falcon_mamba import (
FalconMambaForCausalLM,
FalconMambaModel,
FalconMambaPreTrainedModel,
)
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
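This file follows the library's standard lazy-module pattern: importing the package only registers names, and the torch-backed classes are imported on first attribute access. A small sketch of the resulting behavior, assuming torch is installed:

```python
from transformers.models import falcon_mamba

# The module object is a _LazyModule; modeling_falcon_mamba is only
# imported when one of its attributes is first accessed.
model_cls = falcon_mamba.FalconMambaForCausalLM
print(model_cls.__name__)  # FalconMambaForCausalLM
```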
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""FALCONMAMBA configuration"""
import math
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
# Copied from transformers.models.mamba.configuration_mamba.MambaConfig with mamba->falcon_mamba,Mamba->FalconMamba,MAMBA->FALCON_MAMBA,state-spaces/falcon_mamba-2.8b->tiiuae/falcon-mamba-7b,use_falcon_mambapy->use_mambapy
class FalconMambaConfig(PretrainedConfig):
"""
This is the configuration class to store the configuration of a [`FalconMambaModel`]. It is used to instantiate a FALCON_MAMBA
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the FALCON_MAMBA
[tiiuae/falcon-mamba-7b](https://huggingface.co/tiiuae/falcon-mamba-7b) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 50280):
Vocabulary size of the FALCON_MAMBA model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`FalconMambaModel`].
hidden_size (`int`, *optional*, defaults to 768):
Dimensionality of the embeddings and hidden states.
state_size (`int`, *optional*, defaults to 16): shape of the state space latents.
num_hidden_layers (`int`, *optional*, defaults to 32):
Number of hidden layers in the model.
layer_norm_epsilon (`float`, *optional*, defaults to 1e-05):
The epsilon to use in the layer normalization layers.
pad_token_id (`int`, *optional*, defaults to 0):
Padding token id.
bos_token_id (`int`, *optional*, defaults to 0):
The id of the beginning of sentence token in the vocabulary.
eos_token_id (`int`, *optional*, defaults to 0):
The id of the end of sentence token in the vocabulary.
expand (`int`, *optional*, defaults to 2): Expanding factor used to determine the intermediate size.
conv_kernel (`int`, *optional*, defaults to 4): Size of the convolution kernel.
use_bias (`bool`, *optional*, defaults to `False`):
Whether or not to use bias in ["in_proj", "out_proj"] of the mixer block
use_conv_bias (`bool`, *optional*, defaults to `True`):
Whether or not to use bias in the convolution layer of the mixer block.
hidden_act (`str`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
initializer_range (`float`, *optional*, defaults to 0.1):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
residual_in_fp32 (`bool`, *optional*, defaults to `True`):
Whether or not residuals should be in `float32`. If set to `False`, residuals will keep the same `dtype` as the rest of the model.
time_step_rank (`Union[int,str]`, *optional*, defaults to `"auto"`):
Rank of the discretization projection matrix. `"auto"` means that it will default to `math.ceil(self.hidden_size / 16)`
time_step_scale (`float`, *optional*, defaults to 1.0):
Scale used to scale `dt_proj.bias`.
time_step_min (`float`, *optional*, defaults to 0.001):
Minimum `time_step` used to bound `dt_proj.bias`.
time_step_max (`float`, *optional*, defaults to 0.1):
Maximum `time_step` used to bound `dt_proj.bias`.
time_step_init_scheme (`str`, *optional*, defaults to `"random"`):
Init scheme used for `dt_proj.weight`. Should be one of `["random", "uniform"]`.
time_step_floor (`float`, *optional*, defaults to 0.0001):
Minimum clamping value of the `dt_proj.bias` layer initialization.
rescale_prenorm_residual (`bool`, *optional*, defaults to `False`):
Whether or not to rescale `out_proj` weights when initializing.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the cache should be used.
use_mambapy (`bool`, *optional*, defaults to `False`):
Determines the fallback strategy during training if the CUDA-based official implementation of FalconMamba is not available. If `True`, the mamba.py implementation is used. If `False`, the naive and slower implementation is used. Consider switching to the naive version if memory is limited.
Example:
```python
>>> from transformers import FalconMambaConfig, FalconMambaModel
>>> # Initializing a FalconMamba configuration
>>> configuration = FalconMambaConfig()
>>> # Initializing a model (with random weights) from the configuration
>>> model = FalconMambaModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "falcon_mamba"
def __init__(
self,
vocab_size=50280,
hidden_size=768,
state_size=16,
num_hidden_layers=32,
layer_norm_epsilon=1e-5,
pad_token_id=0,
bos_token_id=0,
eos_token_id=0,
expand=2,
conv_kernel=4,
use_bias=False,
use_conv_bias=True,
hidden_act="silu",
initializer_range=0.1,
residual_in_fp32=True,
time_step_rank="auto",
time_step_scale=1.0,
time_step_min=0.001,
time_step_max=0.1,
time_step_init_scheme="random",
time_step_floor=1e-4,
rescale_prenorm_residual=False,
use_cache=True,
use_mambapy=False,
**kwargs,
):
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.state_size = state_size
self.num_hidden_layers = num_hidden_layers
self.layer_norm_epsilon = layer_norm_epsilon
self.conv_kernel = conv_kernel
self.expand = expand
self.intermediate_size = int(expand * self.hidden_size)
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
self.pad_token_id = pad_token_id
self.use_bias = use_bias
self.use_conv_bias = use_conv_bias
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.time_step_rank = math.ceil(self.hidden_size / 16) if time_step_rank == "auto" else time_step_rank
self.time_step_scale = time_step_scale
self.time_step_min = time_step_min
self.time_step_max = time_step_max
self.time_step_init_scheme = time_step_init_scheme
self.time_step_floor = time_step_floor
self.rescale_prenorm_residual = rescale_prenorm_residual
self.residual_in_fp32 = residual_in_fp32
self.use_cache = use_cache
self.use_mambapy = use_mambapy
super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, pad_token_id=pad_token_id, **kwargs)
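Note the two derived attributes in `__init__`: `intermediate_size` is computed as `expand * hidden_size`, and `time_step_rank="auto"` resolves to `math.ceil(hidden_size / 16)`. A quick check against the defaults:

```python
from transformers import FalconMambaConfig

config = FalconMambaConfig()  # hidden_size=768, expand=2, time_step_rank="auto"
print(config.intermediate_size)  # 1536, i.e. 2 * 768
print(config.time_step_rank)  # 48, i.e. ceil(768 / 16)
```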
@@ -73,6 +73,7 @@ class MambaMixer(nn.Module):
     def __init__(self, config: MambaConfig, layer_idx: int):
         super().__init__()
+        self.config = config
         self.hidden_size = config.hidden_size
         self.ssm_state_size = config.state_size
         self.conv_kernel_size = config.conv_kernel
@@ -364,7 +365,7 @@ class MambaPreTrainedModel(PreTrainedModel):
     config_class = MambaConfig
     base_model_prefix = "backbone"
-    _no_split_modules = ["MambaBlock"]
+    _no_split_modules = ["MambaBlock", "MambaMixer"]
     supports_gradient_checkpointing = True
     _is_stateful = True
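Adding `MambaMixer` to `_no_split_modules` is what the "fix quantization + accelerate" commit above refers to: accelerate's `device_map` dispatch will never shard a mixer block across devices. A hedged sketch of the multi-GPU loading path this enables:

```python
import torch
from transformers import FalconMambaForCausalLM

# With MambaMixer listed in _no_split_modules, accelerate keeps each mixer
# wholly on a single device when sharding weights across available GPUs.
model = FalconMambaForCausalLM.from_pretrained(
    "tiiuae/falcon-mamba-7b",
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
```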
@@ -3895,6 +3895,27 @@ class FalconPreTrainedModel(metaclass=DummyObject):
         requires_backends(self, ["torch"])


+class FalconMambaForCausalLM(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
+class FalconMambaModel(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
+class FalconMambaPreTrainedModel(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
 class FastSpeech2ConformerHifiGan(metaclass=DummyObject):
     _backends = ["torch"]
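These dummies keep the top-level import working in torch-free environments while failing loudly on use; roughly, in a hypothetical environment without PyTorch installed:

```python
# Hypothetical session in an environment without torch.
from transformers import FalconMambaForCausalLM  # the import itself succeeds

try:
    FalconMambaForCausalLM()  # DummyObject triggers requires_backends here
except ImportError as err:
    print(err)  # explains that this class requires the torch backend
```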
@@ -50,6 +50,8 @@ SPECIAL_CASES_TO_ALLOW = {
     "RecurrentGemmaConfig": ["block_types"],
     # used as in the config to define `intermediate_size`
     "MambaConfig": ["expand"],
+    # used as in the config to define `intermediate_size`
+    "FalconMambaConfig": ["expand"],
     # used as `self.bert_model = BertModel(config, ...)`
     "DPRConfig": True,
     "FuyuConfig": True,