Unverified Commit 4f27ee93 authored by Arthur, committed by GitHub

[`Mamba doc`] Post merge updates (#29472)

* post merge update

* nit

* oops
parent 0290ec19
docs/source/en/model_doc/mamba.md

@@ -44,11 +44,8 @@ The original code can be found [here](https://github.com/state-spaces/mamba).
 from transformers import MambaConfig, MambaForCausalLM, AutoTokenizer
 import torch
-tokenizer = AutoTokenizer.from_pretrained("ArthurZ/mamba-130m")
-tokenizer.pad_token = tokenizer.eos_token
-
-model = MambaForCausalLM.from_pretrained("ArthurZ/mamba-130m", vocab_size=50280, num_hidden_layers=24, torch_dtype=torch.float32)
-model.config.use_cache = True
+tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
+model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")
 input_ids = tokenizer("Hey how are you doing?", return_tensors= "pt")["input_ids"]
 out = model.generate(input_ids, max_new_tokens=10)
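For anyone trying the updated doc snippet, here is a minimal end-to-end sketch of the example as it reads after this hunk. The final `print` line is added here for illustration and is not shown in the diff; everything else comes from the new lines above.

```python
from transformers import AutoTokenizer, MambaForCausalLM

# The state-spaces/*-hf conversions bundle a full config and tokenizer,
# which is why the manual vocab_size / num_hidden_layers overrides and the
# pad-token fix from the old snippet are no longer needed.
tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")

input_ids = tokenizer("Hey how are you doing?", return_tensors="pt")["input_ids"]
out = model.generate(input_ids, max_new_tokens=10)
print(tokenizer.batch_decode(out))  # added for illustration; not part of the diff
```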
@@ -63,8 +60,8 @@ from datasets import load_dataset
 from trl import SFTTrainer
 from peft import LoraConfig
 from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
-model_id = "ArthurZ/mamba-2.8b"
-tokenizer = AutoTokenizer.from_pretrained(model_id, pad_token ="<s>")
+model_id = "state-spaces/mamba-130m-hf"
+tokenizer = AutoTokenizer.from_pretrained(model_id)
 model = AutoModelForCausalLM.from_pretrained(model_id)
 dataset = load_dataset("Abirate/english_quotes", split="train")
 training_args = TrainingArguments(
@@ -77,7 +74,7 @@ training_args = TrainingArguments(
 )
 lora_config = LoraConfig(
     r=8,
-    target_modules="all-linear",
+    target_modules=["x_proj", "embeddings", "in_proj", "out_proj"],
     task_type="CAUSAL_LM",
     bias="none"
 )
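Pieced together, the fine-tuning example now reads roughly as below. The `TrainingArguments` values are placeholders for the lines elided between the two hunks, and the `SFTTrainer` wiring (including the dataset's `quote` text field) is assumed from the surrounding doc rather than shown in this diff. The switch away from `target_modules="all-linear"` presumably pins LoRA to Mamba's projection layers and embeddings instead of relying on automatic linear-layer discovery.

```python
from datasets import load_dataset
from trl import SFTTrainer
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments

model_id = "state-spaces/mamba-130m-hf"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
dataset = load_dataset("Abirate/english_quotes", split="train")

training_args = TrainingArguments(
    output_dir="./results",         # placeholder; elided between the hunks
    num_train_epochs=3,             # placeholder
    per_device_train_batch_size=4,  # placeholder
    logging_steps=10,               # placeholder
)
lora_config = LoraConfig(
    r=8,
    target_modules=["x_proj", "embeddings", "in_proj", "out_proj"],
    task_type="CAUSAL_LM",
    bias="none",
)
# Trainer wiring assumed from the surrounding doc, not shown in this diff.
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    peft_config=lora_config,
    train_dataset=dataset,
    dataset_text_field="quote",
)
trainer.train()
```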
src/transformers/models/mamba/modeling_mamba.py

@@ -53,7 +53,7 @@ is_fast_path_available = all(
     (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
 )
 
-_CHECKPOINT_FOR_DOC = "ArthurZ/mamba-130m"
+_CHECKPOINT_FOR_DOC = "state-spaces/mamba-130m-hf"
 _CONFIG_FOR_DOC = "MambaConfig"
 
 MAMBA_PRETRAINED_MODEL_ARCHIVE_LIST = []  # See all Mamba models at https://huggingface.co/models?filter=mamba
@@ -605,7 +605,7 @@ class MambaForCausalLM(MambaPreTrainedModel):
     def _update_model_kwargs_for_generation(
         self, outputs: ModelOutput, model_kwargs: Dict[str, Any], **kwargs
     ) -> Dict[str, Any]:
-        model_kwargs["cache_params"] = outputs["cache_params"]
+        model_kwargs["cache_params"] = outputs.get("cache_params", None)
         return model_kwargs
 
     def prepare_inputs_for_generation(
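The modeling change above guards against generation paths where the model returns no cache: indexing a `ModelOutput` with a missing key raises `KeyError`, while `.get` falls back to `None`. Note also that `ModelOutput` drops `None`-valued fields from its dict view, so even a model that defines `cache_params` but returns `None` for it would trip the old lookup. A dict-level sketch of the difference, since `ModelOutput` is dict-like (this is a standalone illustration, not transformers code):

```python
# Output with no cache entry, e.g. when use_cache=False.
outputs = {"logits": ...}

# Old behavior: outputs["cache_params"] raises KeyError here.
# New behavior: the absent cache becomes an explicit None in model_kwargs.
cache_params = outputs.get("cache_params", None)
print(cache_params)  # None
```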
tests/models/mamba/test_modeling_mamba.py

@@ -406,15 +406,15 @@ class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
 @require_torch
 class MambaIntegrationTests(unittest.TestCase):
     def setUp(self):
-        self.model_id = "ArthurZ/mamba-2.8b"
+        self.model_id = "state-spaces/mamba-2.8b-hf"
         self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
 
     @parameterized.expand([(torch_device,), ("cpu",)])
     def test_simple_generate(self, device):
-        tokenizer = AutoTokenizer.from_pretrained("ArthurZ/mamba-130m")
+        tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
         tokenizer.pad_token = tokenizer.eos_token
-        model = MambaForCausalLM.from_pretrained("ArthurZ/mamba-130m", torch_dtype=torch.float16)
+        model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf", torch_dtype=torch.float16)
         model.to(device)
         model.config.use_cache = True
 
         input_ids = tokenizer("Hey how are you doing?", return_tensors="pt")["input_ids"].to(device)
@@ -444,7 +444,7 @@ class MambaIntegrationTests(unittest.TestCase):
         expected_output = "Hello my name is John and I am a newbie to the world"
 
         input_ids = self.tokenizer("Hello my name is", return_tensors="pt").input_ids.to(device)
-        model = MambaForCausalLM.from_pretrained("ArthurZ/mamba-130m", torch_dtype=torch.float16).to(device)
+        model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf", torch_dtype=torch.float16).to(device)
 
         output = model.generate(input_ids, max_new_tokens=10)
         output_sentence = self.tokenizer.decode(output[0].tolist())
@@ -457,7 +457,7 @@ class MambaIntegrationTests(unittest.TestCase):
         expected_output = "Hello my name is\n\nI am a\n\nI am a"
 
         input_ids = self.tokenizer("Hello my name is", return_tensors="pt").input_ids.to(device)
-        model = MambaForCausalLM.from_pretrained("ArthurZ/mamba-790m", torch_dtype=torch.float16).to(device)
+        model = MambaForCausalLM.from_pretrained("state-spaces/mamba-790m-hf", torch_dtype=torch.float16).to(device)
 
         output = model.generate(input_ids, max_new_tokens=10)
         output_sentence = self.tokenizer.decode(output[0].tolist())
@@ -470,7 +470,7 @@ class MambaIntegrationTests(unittest.TestCase):
         expected_output = "Hello my name is John and I am a\n\nI am a single father of a beautiful daughter. I am a"
 
         input_ids = self.tokenizer("Hello my name is", return_tensors="pt").input_ids.to(device)
-        model = MambaForCausalLM.from_pretrained("ArthurZ/mamba-1.4b", torch_dtype=torch.float16).to(device)
+        model = MambaForCausalLM.from_pretrained("state-spaces/mamba-1.4b-hf", torch_dtype=torch.float16).to(device)
 
         output = model.generate(input_ids, max_new_tokens=20)
         output_sentence = self.tokenizer.decode(output[0].tolist())
@@ -483,7 +483,7 @@ class MambaIntegrationTests(unittest.TestCase):
         expected_output = "Hello my name is John and I am a new member of this forum. I am a retired Marine and I am a member of the Marine Corps League. I am a"
 
         input_ids = self.tokenizer("Hello my name is", return_tensors="pt").input_ids.to(device)
-        model = MambaForCausalLM.from_pretrained("ArthurZ/mamba-2.8b", torch_dtype=torch.float16).to(device)
+        model = MambaForCausalLM.from_pretrained("state-spaces/mamba-2.8b-hf", torch_dtype=torch.float16).to(device)
 
         output = model.generate(input_ids, max_new_tokens=30)
         output_sentence = self.tokenizer.decode(output[0].tolist())
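Each of these integration tests presumably closes with the same comparison, elided by the hunk boundaries; something like the line below. The exact assertion is an assumption and is not shown in this diff.

```python
# Hypothetical closing line of each test above, elided by the hunk boundary.
self.assertEqual(output_sentence, expected_output)
```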