Unverified Commit 0f5dc265 authored by Baber Abbasi, committed by GitHub

Add mamba hf to `mamba_ssm` (#2496)

* add hf mamba to mamba_lm

* fix _model_generate for hf
parent cbc31eb8
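
The new `is_hf` flag lets the same wrapper drive either the native `mamba_ssm` models or their HF-compatible `-hf` checkpoint variants. A minimal usage sketch, assuming the harness's standard entry points (the `mamba_ssm` model name comes from the error message in the diff below; `--model_args` are forwarded to `__init__`, and the task name is just an example):

```python
# CLI: the flag is inferred from the checkpoint name here, since it ends in "hf":
#   lm_eval --model mamba_ssm \
#       --model_args pretrained=state-spaces/mamba-130m-hf \
#       --tasks lambada_openai
#
# Python API equivalent:
from lm_eval.models.mamba_lm import MambaLMWrapper

lm_native = MambaLMWrapper(pretrained="state-spaces/mamba-130m")  # mamba_ssm path
lm_hf = MambaLMWrapper(pretrained="state-spaces/mamba-130m-hf")   # is_hf auto-detected
```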
```diff
@@ -12,6 +12,8 @@ class MambaLMWrapper(HFLM):
     def __init__(
         self,
         pretrained="state-spaces/mamba-130m",
+        # To use the HF compatible variant
+        is_hf: bool = False,
         **kwargs,
     ) -> None:
         """
```
```diff
@@ -52,7 +54,7 @@
         if "backend" in kwargs:
             # mamba currently only supports causal models
             assert kwargs["backend"] == "causal"
-
+        self.is_hf = is_hf or (True if pretrained.endswith("hf") else False)
         super().__init__(
             pretrained=pretrained,
             # set appropriate defaults for tokenizer, max length, etc
```
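
Note that `is_hf or (True if pretrained.endswith("hf") else False)` is equivalent to the plain `is_hf or pretrained.endswith("hf")`; the ternary is redundant. What the auto-detection does, sketched:

```python
# The HF path is chosen for any checkpoint id whose name ends in "hf",
# e.g. the converted state-spaces checkpoints:
for name in ("state-spaces/mamba-130m", "state-spaces/mamba-130m-hf"):
    print(name, "->", name.endswith("hf"))
# state-spaces/mamba-130m -> False
# state-spaces/mamba-130m-hf -> True
```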
```diff
@@ -67,15 +69,18 @@
         pretrained: str,
         **kwargs,
     ) -> None:
-        try:
-            from mamba_ssm.utils.hf import load_config_hf  # noqa: F811
-        except ModuleNotFoundError as exception:
-            raise type(exception)(
-                "attempted to use 'mamba_ssm' LM type, but package `mamba_ssm` is not installed. \
-please install mamba via `pip install lm-eval[mamba]` or `pip install -e .[mamba]`",
-            )
-        self._config = load_config_hf(pretrained)
+        if self.is_hf:
+            super()._get_config(pretrained, **kwargs)
+        else:
+            try:
+                from mamba_ssm.utils.hf import load_config_hf  # noqa: F811
+            except ModuleNotFoundError as exception:
+                raise type(exception)(
+                    "attempted to use 'mamba_ssm' LM type, but package `mamba_ssm` is not installed. \
+please install mamba via `pip install lm-eval[mamba]` or `pip install -e .[mamba]`",
+                )
+            self._config = load_config_hf(pretrained)

     def _create_model(
         self,
```
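
The config paths now diverge: the HF branch defers to `HFLM._get_config`, which in the harness ultimately resolves the config through `transformers`, while the native branch keeps `mamba_ssm`'s own loader. A rough sketch of what each branch amounts to, assuming both packages are installed:

```python
from mamba_ssm.utils.hf import load_config_hf
from transformers import AutoConfig

native_config = load_config_hf("state-spaces/mamba-130m")  # plain dict parsed from config.json
hf_config = AutoConfig.from_pretrained("state-spaces/mamba-130m-hf")  # transformers config object
```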
```diff
@@ -86,24 +91,32 @@
         # Mamba does not support arbitrary HF from_pretrained() args
         **kwargs,
     ) -> None:
-        try:
-            from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel  # noqa: F811
-        except ModuleNotFoundError as exception:
-            raise type(exception)(
-                "attempted to use 'mamba_ssm' LM type, but package `mamba_ssm` is not installed. \
-please install mamba via `pip install lm-eval[mamba]` or `pip install -e .[mamba]`",
-            )
-        self._model = MambaLMHeadModel.from_pretrained(
-            pretrained,
-            device=self._device,
-            dtype=torch.float16
-            if dtype == "auto"
-            else lm_eval.models.utils.get_dtype(dtype),
-        )
+        if self.is_hf:
+            super()._create_model(pretrained, dtype=dtype, **kwargs)
+        else:
+            try:
+                from mamba_ssm.models.mixer_seq_simple import (
+                    MambaLMHeadModel,  # noqa: F811
+                )
+            except ModuleNotFoundError as exception:
+                raise type(exception)(
+                    "attempted to use 'mamba_ssm' LM type, but package `mamba_ssm` is not installed. \
+please install mamba via `pip install lm-eval[mamba]` or `pip install -e .[mamba]`",
+                )
+            self._model = MambaLMHeadModel.from_pretrained(
+                pretrained,
+                device=self._device,
+                dtype=torch.float16
+                if dtype == "auto"
+                else lm_eval.models.utils.get_dtype(dtype),
+            )

     def _model_generate(self, context, max_length, stop, **generation_kwargs):
-        for key in ("do_sample", "attention_mask"):
+        remove_arg = (
+            ["attention_mask"] if self.is_hf else ["do_sample", "attention_mask"]
+        )
+        for key in remove_arg:
             if key in generation_kwargs:
                 generation_kwargs.pop(key)
```
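
The split in the argument scrubbing exists because native `MambaLMHeadModel.generate` exposes its own sampling flags and takes no `do_sample` or `attention_mask`, while the HF `generate` does understand `do_sample`, so only `attention_mask` needs to go. The same filter isolated as a sketch (`_scrub_generation_kwargs` is a hypothetical name):

```python
def _scrub_generation_kwargs(generation_kwargs: dict, is_hf: bool) -> dict:
    # Mamba has no attention, so attention_mask is always dropped; the native
    # backend additionally rejects HF-style do_sample.
    remove_arg = ["attention_mask"] if is_hf else ["do_sample", "attention_mask"]
    for key in remove_arg:
        generation_kwargs.pop(key, None)
    return generation_kwargs
```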
```diff
@@ -116,11 +129,37 @@
         #     self.tokenizer, stop, 1, context.shape[0]
         # )
-        return self.model.generate(
-            input_ids=context,
-            max_length=max_length,
-            # stopping_criteria=stopping_criteria,
-            # pad_token_id=self.tokenizer.pad_token_id,
-            # use_cache=True,
-            **generation_kwargs,
-        )
+        if not self.is_hf:
+            return self.model.generate(
+                input_ids=context,
+                max_length=max_length,
+                # stopping_criteria=stopping_criteria,
+                # pad_token_id=self.tokenizer.pad_token_id,
+                # use_cache=True,
+                **generation_kwargs,
+            )
+        else:
+            stopping_criteria = lm_eval.models.utils.stop_sequences_criteria(
+                self.tokenizer,
+                stop,
+                context.shape[1],
+                context.shape[0],
+            )
+            generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0)
+            do_sample = generation_kwargs.get("do_sample", None)
+
+            # The temperature has to be a strictly positive float --
+            # if it is 0.0, use greedy decoding strategies
+            if generation_kwargs.get("temperature") == 0.0 and do_sample is None:
+                generation_kwargs["do_sample"] = do_sample = False
+            if do_sample is False and generation_kwargs.get("temperature") == 0.0:
+                generation_kwargs.pop("temperature")
+
+            return self.model.generate(
+                input_ids=context,
+                max_length=max_length,
+                stopping_criteria=stopping_criteria,
+                pad_token_id=self.tokenizer.pad_token_id,
+                use_cache=True,
+                **generation_kwargs,
+            )
```
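
The HF branch also restores proper stop-sequence handling via `stop_sequences_criteria` and normalizes the sampling flags: `temperature=0.0` is the harness convention for greedy decoding, but HF's `generate()` requires a strictly positive temperature, so the pair is rewritten before the call. The same normalization isolated as a sketch (`_normalize_sampling_flags` is a hypothetical name):

```python
def _normalize_sampling_flags(generation_kwargs: dict) -> dict:
    generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0)
    do_sample = generation_kwargs.get("do_sample", None)
    # temperature == 0.0 with no explicit do_sample means greedy decoding
    if generation_kwargs["temperature"] == 0.0 and do_sample is None:
        generation_kwargs["do_sample"] = do_sample = False
    # greedy decoding ignores temperature; drop it so generate() does not reject 0.0
    if do_sample is False and generation_kwargs["temperature"] == 0.0:
        generation_kwargs.pop("temperature")
    return generation_kwargs

assert _normalize_sampling_flags({}) == {"do_sample": False}
```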