Unverified commit fb1c62e9, authored by Arthur, committed by GitHub

[`Add Mamba`] Adds support for the `Mamba` models (#28094)
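
For context, here is a minimal usage sketch of what this PR adds. The checkpoint name below is an assumption for illustration and is not taken from this commit:

    # Hedged sketch: text generation with the newly added Mamba classes.
    from transformers import AutoTokenizer, MambaForCausalLM

    # Assumed checkpoint name, for illustration only.
    tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
    model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")

    inputs = tokenizer("Hey how are you doing?", return_tensors="pt")
    out = model.generate(**inputs, max_new_tokens=10)
    print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])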



* initial-commit

* start cleaning

* small nits

* small nits

* current updates

* add kernels

* small refactoring little step

* add comments

* styling

* nit

* nits

* Style

* Small changes

* Push dummy mamba simple slow

* nit

* Use original names

* Use original names and remove norm

* Updates for inference params

* Style and updates

* nits

* Match logits

* Add a test

* Add expected generated text

* nits doc, imports and styling

* style

* oops

* don't install kernels; invite users to install the required kernels

* let users use the original packages

* styling

* nits

* fix some copies

* update doc

* fix-copies

* styling done

* nits

* fix import check

* runs, but wrong CUDA results

* mamba CUDA works :)

* fix the fast path

* config naming nits

* conversion script is not required at this stage

* finish fixing the fast path: generation make sense now!

* nit

* Let's start working on the CIs

* style

* better style

* more nits

* test nit

* quick fix for now

* nits

* nit

* nit

* nit

* nits

* update the rest of the tests

* fixup

* update test

* nit

* some fixes

* nits

* update test values

* fix styling

* nit

* support peft

* integration tests require torch

* also add slow markers

* styling

* choose forward wisely

* nits

* update tests

* fix gradient checkpointing

* fixup

* nit

* fix doc

* check copies

* fix the docstring

* fix some more tests

* style

* fix beam search

* add init scheme

* update

* nit

* fix

* fixup the doc

* fix the doc

* fixup

* tentative update, but the slow path is no longer good

* nit

* should we always use float32?

* nits

* revert wrong changes

* res in float32

* cleanup

* skip fmt for now

* update generation values

* update test values running original model

* fixup

* update tests + rename `inference_params` to `cache_params` + make sure training does not use `cache_params` (see the sketch after this list)

* small nits

* more nits

* fix final CIs

* style

* nit doc

* I hope final doc nits

* nit

* 🫠

* final touch!

* fix torch import

* Apply suggestions from code review
Co-authored-by: Lysandre Debut <hi@lysand.re>

* Apply suggestions from code review

* fix fix and fix

* fix base model prefix!

* nit

* Update src/transformers/models/mamba/__init__.py

* Update docs/source/en/model_doc/mamba.md
Co-authored-by: Lysandre Debut <hi@lysand.re>

* nit

---------
Co-authored-by: Lysandre Debut <hi@lysand.re>
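
The `inference_params` → `cache_params` rename mentioned above refers to the recurrent state carried between decoding steps. A minimal sketch of that flow, assuming a loaded `MambaForCausalLM` named `model` and token tensors `input_ids` / `next_token` (names follow the commit messages; the merged API may differ in detail):

    # Hedged sketch: at inference the conv + SSM states are returned and
    # fed back in; during training no cache_params are created or consumed.
    outputs = model(input_ids=input_ids, use_cache=True)
    cache = outputs.cache_params                      # per-layer states
    step = model(input_ids=next_token, cache_params=cache, use_cache=True)

    model.train()
    train_out = model(input_ids=input_ids, labels=input_ids)  # no cache used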
parent 87a0783d
[This diff is collapsed; contents not shown.]
@@ -5022,6 +5022,30 @@ class M2M100PreTrainedModel(metaclass=DummyObject):
         requires_backends(self, ["torch"])


+MAMBA_PRETRAINED_MODEL_ARCHIVE_LIST = None
+
+
+class MambaForCausalLM(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
+class MambaModel(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
+class MambaPreTrainedModel(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
 class MarianForCausalLM(metaclass=DummyObject):
     _backends = ["torch"]
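
These dummy classes let `import transformers` succeed without torch installed; touching one raises an informative error instead of a hard import failure. A simplified sketch of the mechanism (the real `DummyObject` and `requires_backends` live in `transformers.utils`; this is abbreviated, not the exact library code):

    def requires_backends(obj, backends):
        # Simplified: the real helper raises only when a backend is missing.
        name = getattr(obj, "__name__", obj.__class__.__name__)
        raise ImportError(f"{name} requires the {backends} backend(s).")

    class DummyObject(type):
        # Metaclass for placeholders: any public class-level attribute access
        # (e.g. MambaModel.from_pretrained) also triggers the backend check.
        def __getattribute__(cls, key):
            if key.startswith("_"):
                return super().__getattribute__(key)
            requires_backends(cls, cls._backends)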
@@ -307,6 +307,27 @@ def is_torch_cuda_available():
     return False


+def is_mamba_ssm_available():
+    if is_torch_available():
+        import torch
+
+        if not torch.cuda.is_available():
+            return False
+        else:
+            return _is_package_available("mamba_ssm")
+    return False
+
+
+def is_causal_conv1d_available():
+    if is_torch_available():
+        import torch
+
+        if not torch.cuda.is_available():
+            return False
+        return _is_package_available("causal_conv1d")
+    return False
+
+
 def is_torch_mps_available():
     if is_torch_available():
         import torch
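
These two checks let the modeling code choose between the fused CUDA kernels and a pure-PyTorch fallback; per the commit messages, the kernels are deliberately not installed as dependencies, so users install `mamba-ssm` and `causal-conv1d` themselves. A hedged sketch of how the checks are typically consumed (the exact imports in the merged modeling file may differ):

    from transformers.utils.import_utils import (
        is_causal_conv1d_available,
        is_mamba_ssm_available,
    )

    # Fast path only when both optional kernel packages import cleanly;
    # the checks above already verify that CUDA is available.
    is_fast_path_available = is_mamba_ssm_available() and is_causal_conv1d_available()

    if is_fast_path_available:
        from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
        from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn
    else:
        causal_conv1d_fn = causal_conv1d_update = mamba_inner_fn = None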
[This diff is collapsed; contents not shown.]
@@ -34,6 +34,8 @@ CONFIG_MAPPING = transformers.models.auto.configuration_auto.CONFIG_MAPPING

 SPECIAL_CASES_TO_ALLOW = {
     # used to compute the property `self.chunk_length`
     "EncodecConfig": ["overlap"],
+    # used as in the config to define `intermediate_size`
+    "MambaConfig": ["expand"],
     # used as `self.bert_model = BertModel(config, ...)`
     "DPRConfig": True,
     "FuyuConfig": True,
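
`expand` is whitelisted here because the config consumes it to derive another attribute rather than exposing it directly to the modeling code. A hedged sketch of the relationship (hypothetical class name), following the standard Mamba convention that the inner dimension is `expand * hidden_size`:

    # Sketch: why `expand` never appears on its own in the modeling code --
    # it only sets the inner projection width at config time.
    class MambaConfigSketch:
        def __init__(self, hidden_size=768, expand=2):
            self.hidden_size = hidden_size
            self.expand = expand
            # inner dimension used by each Mamba block's projections
            self.intermediate_size = int(expand * hidden_size)

    cfg = MambaConfigSketch()
    assert cfg.intermediate_size == 1536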