gaoqiong / lm-evaluation-harness · Commits

Commit 0f5dc265 (unverified)
Authored Nov 18, 2024 by Baber Abbasi; committed via GitHub, Nov 18, 2024

Add mamba hf to `mamba_ssm` (#2496)

* add hf mamba to mamba_lm
* fix _model_generate for hf
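For orientation, a minimal sketch of how the HF-compatible path might be exercised once this lands. It assumes the wrapper is registered under the `mamba_ssm` model name (the error strings in this file refer to the "'mamba_ssm' LM type") and that `state-spaces/mamba-130m-hf` is one of the HF-compatible checkpoints the new name heuristic targets; the `lambada_openai` task is just an illustrative choice:

    import lm_eval

    # Hypothetical invocation: pretrained ends in "hf", so the wrapper
    # should route through the HFLM code paths added in this commit.
    results = lm_eval.simple_evaluate(
        model="mamba_ssm",
        model_args="pretrained=state-spaces/mamba-130m-hf",
        tasks=["lambada_openai"],
    )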
Parent: cbc31eb8

Showing 1 changed file with 72 additions and 33 deletions (+72 −33)

lm_eval/models/mamba_lm.py
@@ -12,6 +12,8 @@ class MambaLMWrapper(HFLM):
     def __init__(
         self,
         pretrained="state-spaces/mamba-130m",
+        # To use the HF compatible variant
+        is_hf: bool = False,
         **kwargs,
     ) -> None:
         """
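A hedged sketch of what the new flag enables when constructing the wrapper directly (the harness normally instantiates it through its model registry; `my-org/my-mamba-export` below is a made-up checkpoint name for illustration):

    from lm_eval.models.mamba_lm import MambaLMWrapper

    # Auto-detected HF path: the checkpoint name ends in "hf".
    lm = MambaLMWrapper(pretrained="state-spaces/mamba-130m-hf")

    # Forced HF path: is_hf=True wins regardless of the name
    # (hypothetical checkpoint, shown only for the flag's semantics).
    # lm = MambaLMWrapper(pretrained="my-org/my-mamba-export", is_hf=True)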
@@ -52,7 +54,7 @@ class MambaLMWrapper(HFLM):
         if "backend" in kwargs:
             # mamba currently only supports causal models
             assert kwargs["backend"] == "causal"
-
+        self.is_hf = is_hf or (True if pretrained.endswith("hf") else False)
         super().__init__(
             pretrained=pretrained,
             # set appropriate defaults for tokenizer, max length, etc
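The detection line reads densely; here is a standalone restatement of its behavior (hypothetical helper, not part of the commit):

    def resolve_is_hf(pretrained: str, is_hf: bool = False) -> bool:
        # Equivalent to `is_hf or (True if pretrained.endswith("hf") else False)`:
        # an explicit is_hf=True wins, otherwise fall back to the name heuristic.
        return is_hf or pretrained.endswith("hf")

    assert resolve_is_hf("state-spaces/mamba-130m-hf") is True   # HF variant
    assert resolve_is_hf("state-spaces/mamba-130m") is False     # mamba_ssm variant
    assert resolve_is_hf("state-spaces/mamba-130m", is_hf=True) is True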
@@ -67,15 +69,18 @@ class MambaLMWrapper(HFLM):
         pretrained: str,
         **kwargs,
     ) -> None:
-        try:
-            from mamba_ssm.utils.hf import load_config_hf  # noqa: F811
-        except ModuleNotFoundError as exception:
-            raise type(exception)(
-                "attempted to use 'mamba_ssm' LM type, but package `mamba_ssm` is not installed. \
-please install mamba via `pip install lm-eval[mamba]` or `pip install -e .[mamba]`",
-            )
-        self._config = load_config_hf(pretrained)
+        if self.is_hf:
+            super()._get_config(pretrained, **kwargs)
+        else:
+            try:
+                from mamba_ssm.utils.hf import load_config_hf  # noqa: F811
+            except ModuleNotFoundError as exception:
+                raise type(exception)(
+                    "attempted to use 'mamba_ssm' LM type, but package `mamba_ssm` is not installed. \
+please install mamba via `pip install lm-eval[mamba]` or `pip install -e .[mamba]`",
+                )
+            self._config = load_config_hf(pretrained)

     def _create_model(
         self,
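Both branches here (and `_create_model` below) reuse the same guarded-import idiom: re-raise with the original exception type but an actionable install hint. A minimal standalone sketch of the pattern, with a hypothetical helper name:

    def import_mamba_ssm_or_hint():
        # Mirrors the try/except blocks in this file: keep the exception
        # type, replace the message with installation instructions.
        try:
            import mamba_ssm
        except ModuleNotFoundError as exception:
            raise type(exception)(
                "attempted to use 'mamba_ssm' LM type, but package `mamba_ssm` "
                "is not installed. please install mamba via "
                "`pip install lm-eval[mamba]` or `pip install -e .[mamba]`",
            )
        return mamba_ssm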
@@ -86,24 +91,32 @@ please install mamba via `pip install lm-eval[mamba]` or `pip install -e .[mamba]`
         # Mamba does not support arbitrary HF from_pretrained() args
         **kwargs,
     ) -> None:
-        try:
-            from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel  # noqa: F811
-        except ModuleNotFoundError as exception:
-            raise type(exception)(
-                "attempted to use 'mamba_ssm' LM type, but package `mamba_ssm` is not installed. \
-please install mamba via `pip install lm-eval[mamba]` or `pip install -e .[mamba]`",
-            )
-        self._model = MambaLMHeadModel.from_pretrained(
-            pretrained,
-            device=self._device,
-            dtype=torch.float16
-            if dtype == "auto"
-            else lm_eval.models.utils.get_dtype(dtype),
-        )
+        if self.is_hf:
+            super()._create_model(pretrained, dtype=dtype, **kwargs)
+        else:
+            try:
+                from mamba_ssm.models.mixer_seq_simple import (
+                    MambaLMHeadModel,  # noqa: F811
+                )
+            except ModuleNotFoundError as exception:
+                raise type(exception)(
+                    "attempted to use 'mamba_ssm' LM type, but package `mamba_ssm` is not installed. \
+please install mamba via `pip install lm-eval[mamba]` or `pip install -e .[mamba]`",
+                )
+            self._model = MambaLMHeadModel.from_pretrained(
+                pretrained,
+                device=self._device,
+                dtype=torch.float16
+                if dtype == "auto"
+                else lm_eval.models.utils.get_dtype(dtype),
+            )

     def _model_generate(self, context, max_length, stop, **generation_kwargs):
-        for key in ("do_sample", "attention_mask"):
+        remove_arg = (
+            ["attention_mask"] if self.is_hf else ["do_sample", "attention_mask"]
+        )
+        for key in remove_arg:
             if key in generation_kwargs:
                 generation_kwargs.pop(key)
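The non-HF branch defaults `dtype="auto"` to `torch.float16` and otherwise defers to `lm_eval.models.utils.get_dtype`. A simplified stand-in for that resolution, under the assumption that `get_dtype` maps a string name to the matching `torch.dtype` and passes `torch.dtype` values through unchanged:

    import torch

    def resolve_mamba_dtype(dtype):
        # "auto" -> float16, mirroring the from_pretrained call above;
        # the string/dtype handling is an assumed simplification of get_dtype.
        if dtype == "auto":
            return torch.float16
        return getattr(torch, dtype) if isinstance(dtype, str) else dtype

    assert resolve_mamba_dtype("auto") is torch.float16
    assert resolve_mamba_dtype("bfloat16") is torch.bfloat16
    assert resolve_mamba_dtype(torch.float32) is torch.float32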
@@ -116,11 +129,37 @@ please install mamba via `pip install lm-eval[mamba]` or `pip install -e .[mamba]`
         #     self.tokenizer, stop, 1, context.shape[0]
         # )
-        return self.model.generate(
-            input_ids=context,
-            max_length=max_length,
-            # stopping_criteria=stopping_criteria,
-            # pad_token_id=self.tokenizer.pad_token_id,
-            # use_cache=True,
-            **generation_kwargs,
-        )
+        if not self.is_hf:
+            return self.model.generate(
+                input_ids=context,
+                max_length=max_length,
+                # stopping_criteria=stopping_criteria,
+                # pad_token_id=self.tokenizer.pad_token_id,
+                # use_cache=True,
+                **generation_kwargs,
+            )
+        else:
+            stopping_criteria = lm_eval.models.utils.stop_sequences_criteria(
+                self.tokenizer,
+                stop,
+                context.shape[1],
+                context.shape[0],
+            )
+
+            generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0)
+            do_sample = generation_kwargs.get("do_sample", None)
+
+            # The temperature has to be a strictly positive float -- if it is 0.0, use greedy decoding strategies
+            if generation_kwargs.get("temperature") == 0.0 and do_sample is None:
+                generation_kwargs["do_sample"] = do_sample = False
+
+            if do_sample is False and generation_kwargs.get("temperature") == 0.0:
+                generation_kwargs.pop("temperature")
+
+            return self.model.generate(
+                input_ids=context,
+                max_length=max_length,
+                stopping_criteria=stopping_criteria,
+                pad_token_id=self.tokenizer.pad_token_id,
+                use_cache=True,
+                **generation_kwargs,
+            )
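The temperature/do_sample juggling in the HF branch exists because `transformers`' `generate` rejects a sampling temperature of 0.0, so a zero temperature with no explicit `do_sample` is reinterpreted as greedy decoding and the now-meaningless temperature is dropped. A standalone restatement of that normalization (hypothetical helper):

    def normalize_greedy(generation_kwargs: dict) -> dict:
        # Mirror the HF branch above: default temperature to 0.0, then
        # treat temperature == 0.0 with unspecified do_sample as greedy.
        generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0)
        do_sample = generation_kwargs.get("do_sample", None)
        if generation_kwargs["temperature"] == 0.0 and do_sample is None:
            generation_kwargs["do_sample"] = do_sample = False
        if do_sample is False and generation_kwargs["temperature"] == 0.0:
            generation_kwargs.pop("temperature")
        return generation_kwargs

    assert normalize_greedy({}) == {"do_sample": False}
    assert normalize_greedy({"temperature": 0.7}) == {"temperature": 0.7}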