Unverified commit 0f5dc265, authored by Baber Abbasi, committed by GitHub

Add mamba hf to `mamba_ssm` (#2496)

* add HF-compatible Mamba support to the `mamba_lm` wrapper

* fix `_model_generate` for the HF variant
parent cbc31eb8
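For context on how the new flag is exercised: `is_hf` can be passed explicitly through `model_args`, and is also inferred when the checkpoint name ends in `hf` (e.g. the `state-spaces/mamba-130m-hf` conversions). A minimal sketch of both paths through the harness Python API (the task choice is illustrative, not part of this patch):

```python
import lm_eval

# Native mamba_ssm path: weights load via MambaLMHeadModel.from_pretrained
results = lm_eval.simple_evaluate(
    model="mamba_ssm",
    model_args="pretrained=state-spaces/mamba-130m",
    tasks=["lambada_openai"],
)

# HF path: the trailing "hf" flips is_hf, so config, weights, and generate()
# all go through the standard transformers/HFLM machinery instead
results_hf = lm_eval.simple_evaluate(
    model="mamba_ssm",
    model_args="pretrained=state-spaces/mamba-130m-hf",
    tasks=["lambada_openai"],
)
```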
@@ -12,6 +12,8 @@ class MambaLMWrapper(HFLM):
     def __init__(
         self,
         pretrained="state-spaces/mamba-130m",
+        # To use the HF compatible variant
+        is_hf: bool = False,
         **kwargs,
     ) -> None:
         """
@@ -52,7 +54,7 @@ class MambaLMWrapper(HFLM):
         if "backend" in kwargs:
             # mamba currently only supports causal models
             assert kwargs["backend"] == "causal"
-
+        self.is_hf = is_hf or (True if pretrained.endswith("hf") else False)
         super().__init__(
             pretrained=pretrained,
             # set appropriate defaults for tokenizer, max length, etc
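One behavioral note on the suffix check: because `is_hf` is OR-ed with `pretrained.endswith("hf")`, an explicit `is_hf=False` cannot override a name ending in `hf` (and the `True if ... else False` wrapper is redundant, since `endswith` already returns a bool). Illustratively:

```python
# How is_hf resolves for a few inputs (illustrative helper, not in the patch)
def resolve_is_hf(pretrained: str, is_hf: bool = False) -> bool:
    return is_hf or pretrained.endswith("hf")

assert resolve_is_hf("state-spaces/mamba-130m") is False    # native path
assert resolve_is_hf("state-spaces/mamba-130m-hf") is True  # HF path
assert resolve_is_hf("state-spaces/mamba-130m", is_hf=True) is True
assert resolve_is_hf("state-spaces/mamba-130m-hf", is_hf=False) is True  # suffix wins
```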
@@ -67,12 +69,15 @@ class MambaLMWrapper(HFLM):
         pretrained: str,
         **kwargs,
     ) -> None:
-        try:
-            from mamba_ssm.utils.hf import load_config_hf  # noqa: F811
-        except ModuleNotFoundError as exception:
-            raise type(exception)(
-                "attempted to use 'mamba_ssm' LM type, but package `mamba_ssm` is not installed. \
-please install mamba via `pip install lm-eval[mamba]` or `pip install -e .[mamba]`",
-            )
-        self._config = load_config_hf(pretrained)
+        if self.is_hf:
+            super()._get_config(pretrained, **kwargs)
+        else:
+            try:
+                from mamba_ssm.utils.hf import load_config_hf  # noqa: F811
+            except ModuleNotFoundError as exception:
+                raise type(exception)(
+                    "attempted to use 'mamba_ssm' LM type, but package `mamba_ssm` is not installed. \
+please install mamba via `pip install lm-eval[mamba]` or `pip install -e .[mamba]`",
+                )
+            self._config = load_config_hf(pretrained)
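The two branches hand back different config shapes, which matters when reading attributes off `self._config` later: `mamba_ssm`'s `load_config_hf` returns the raw `config.json` as a plain dict, while HFLM's `_get_config` goes through `transformers.AutoConfig`. A hedged sketch of the difference (field names per the respective configs):

```python
from mamba_ssm.utils.hf import load_config_hf
from transformers import AutoConfig

cfg_native = load_config_hf("state-spaces/mamba-130m")
print(cfg_native["d_model"], cfg_native["n_layer"])  # dict access, e.g. 768 24

cfg_hf = AutoConfig.from_pretrained("state-spaces/mamba-130m-hf")
print(cfg_hf.hidden_size, cfg_hf.num_hidden_layers)  # attribute access on a MambaConfig
```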
@@ -86,12 +91,17 @@ please install mamba via `pip install lm-eval[mamba]` or `pip install -e .[mamba
         # Mamba does not support arbitrary HF from_pretrained() args
         **kwargs,
     ) -> None:
-        try:
-            from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel  # noqa: F811
-        except ModuleNotFoundError as exception:
-            raise type(exception)(
-                "attempted to use 'mamba_ssm' LM type, but package `mamba_ssm` is not installed. \
-please install mamba via `pip install lm-eval[mamba]` or `pip install -e .[mamba]`",
-            )
-        self._model = MambaLMHeadModel.from_pretrained(
+        if self.is_hf:
+            super()._create_model(pretrained, dtype=dtype, **kwargs)
+        else:
+            try:
+                from mamba_ssm.models.mixer_seq_simple import (
+                    MambaLMHeadModel,  # noqa: F811
+                )
+            except ModuleNotFoundError as exception:
+                raise type(exception)(
+                    "attempted to use 'mamba_ssm' LM type, but package `mamba_ssm` is not installed. \
+please install mamba via `pip install lm-eval[mamba]` or `pip install -e .[mamba]`",
+                )
+            self._model = MambaLMHeadModel.from_pretrained(
@@ -103,7 +113,10 @@ please install mamba via `pip install lm-eval[mamba]` or `pip install -e .[mamba
         )

     def _model_generate(self, context, max_length, stop, **generation_kwargs):
-        for key in ("do_sample", "attention_mask"):
+        remove_arg = (
+            ["attention_mask"] if self.is_hf else ["do_sample", "attention_mask"]
+        )
+        for key in remove_arg:
             if key in generation_kwargs:
                 generation_kwargs.pop(key)
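The asymmetry exists because `mamba_ssm`'s custom `MambaLMHeadModel.generate` does not accept HF-style `do_sample` or `attention_mask` kwargs (sampling there is driven by `top_k`/`top_p`/`temperature`), whereas HF `generate()` understands `do_sample` and the HF branch below relies on it for greedy decoding; only `attention_mask` is dropped there since the harness passes bare `input_ids`. A small sketch of the resulting filtering:

```python
# Illustrative: what each branch forwards, given typical harness kwargs
generation_kwargs = {"do_sample": False, "temperature": 0.0, "attention_mask": None}

native = {k: v for k, v in generation_kwargs.items()
          if k not in ("do_sample", "attention_mask")}  # {'temperature': 0.0}
hf = {k: v for k, v in generation_kwargs.items()
      if k != "attention_mask"}  # {'do_sample': False, 'temperature': 0.0}
```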
@@ -116,6 +129,7 @@ please install mamba via `pip install lm-eval[mamba]` or `pip install -e .[mamba
         # self.tokenizer, stop, 1, context.shape[0]
         # )

-        return self.model.generate(
-            input_ids=context,
-            max_length=max_length,
+        if not self.is_hf:
+            return self.model.generate(
+                input_ids=context,
+                max_length=max_length,
@@ -124,3 +138,28 @@ please install mamba via `pip install lm-eval[mamba]` or `pip install -e .[mamba
-            # use_cache=True,
-            **generation_kwargs,
-        )
+                # use_cache=True,
+                **generation_kwargs,
+            )
+        else:
+            stopping_criteria = lm_eval.models.utils.stop_sequences_criteria(
+                self.tokenizer,
+                stop,
+                context.shape[1],
+                context.shape[0],
+            )
+
+            generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0)
+            do_sample = generation_kwargs.get("do_sample", None)
+
+            # The temperature has to be a strictly positive float -- if it is 0.0, use greedy decoding strategies
+            if generation_kwargs.get("temperature") == 0.0 and do_sample is None:
+                generation_kwargs["do_sample"] = do_sample = False
+
+            if do_sample is False and generation_kwargs.get("temperature") == 0.0:
+                generation_kwargs.pop("temperature")
+
+            return self.model.generate(
+                input_ids=context,
+                max_length=max_length,
+                stopping_criteria=stopping_criteria,
+                pad_token_id=self.tokenizer.pad_token_id,
+                use_cache=True,
+                **generation_kwargs,
+            )
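The temperature bookkeeping in the HF branch mirrors HFLM's greedy-decoding convention: an unspecified temperature defaults to 0.0, which is read as "decode greedily", so `do_sample` is forced to `False` and the zero temperature is popped before it can trip HF `generate()`'s strictly-positive check. A self-contained restatement of that logic (the helper name is mine, not in the patch):

```python
def normalize_sampling_kwargs(generation_kwargs: dict) -> dict:
    """Mirror the greedy-decoding convention used above (illustrative helper)."""
    generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0)
    do_sample = generation_kwargs.get("do_sample", None)
    # HF sampling needs a strictly positive temperature; 0.0 means "be greedy"
    if generation_kwargs["temperature"] == 0.0 and do_sample is None:
        generation_kwargs["do_sample"] = do_sample = False
    if do_sample is False and generation_kwargs["temperature"] == 0.0:
        generation_kwargs.pop("temperature")
    return generation_kwargs

assert normalize_sampling_kwargs({}) == {"do_sample": False}
assert normalize_sampling_kwargs({"temperature": 0.8, "do_sample": True}) == {
    "temperature": 0.8,
    "do_sample": True,
}
```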