gaoqiong / lm-evaluation-harness · Commits

Commit 22f5ad26, authored Jan 23, 2025 by Baber
parent bb9ed305

    add rwkv

1 changed file with 116 additions and 0 deletions:
lm_eval/models/rwkv.py (new file, 0 → 100644, +116 −0)
from typing import Optional, Union

import torch

import lm_eval.models.utils
from lm_eval.api.registry import register_model
from lm_eval.models.huggingface import HFLM
@register_model("rwkv")
class MambaLMWrapper(HFLM):
    def __init__(
        self,
        pretrained="",
        # To use the HF compatible variant
        is_hf: bool = False,
        **kwargs,
    ) -> None:
        if "backend" in kwargs:
            assert kwargs["backend"] == "causal"

        self.is_hf = is_hf or pretrained.endswith("hf")
        self.tokenizer = kwargs["tokenizer"]
        self.pretrained = pretrained

        super().__init__(
            pretrained=pretrained,
            # set appropriate defaults for tokenizer, max length, etc
            backend=kwargs.pop("backend", "causal"),
            tokenizer=self.tokenizer,
            max_length=kwargs.pop("max_length", 4096),
            **kwargs,
        )
    def _get_config(
        self,
        pretrained: str,
        **kwargs,
    ) -> None:
        if self.is_hf:
            super()._get_config(pretrained, **kwargs)
        else:
            self._config = {}
    def _create_model(
        self,
        pretrained: str,
        dtype: Optional[Union[str, torch.dtype]] = "fp16",
        **kwargs,
    ) -> None:
        if self.is_hf:
            super()._create_model(pretrained, dtype=dtype, **kwargs)
        else:
            try:
                from rwkv.model import RWKV
            except ModuleNotFoundError as exception:
                raise type(exception)(
                    "attempted to use 'rwkv' LM type, but package `rwkv` is not installed. "
                    "please install it via `pip install rwkv`",
                ) from exception

            import os

            os.environ["RWKV_JIT_ON"] = "1"
            os.environ["RWKV_CUDA_ON"] = "1"
            os.environ["RWKV_V7_ON"] = "1"

            self._model = RWKV(model=self.pretrained, strategy=f"cuda {dtype}")
    def _model_generate(self, context, max_length, stop, **generation_kwargs):
        remove_arg = (
            ["attention_mask"] if self.is_hf else ["do_sample", "attention_mask"]
        )
        for key in remove_arg:
            if key in generation_kwargs:
                generation_kwargs.pop(key)

        if not self.is_hf:
            # Native rwkv path: feed the prompt through the recurrent state in
            # fixed-size chunks, then decode greedily one token at a time.
            # Assumes batch size 1; the rwkv package consumes flat lists of ids.
            CHUNK_SIZE = 4096

            all_outputs = []
            tokens = context[0].tolist()
            prefill_ids, next_token = tokens[:-1], tokens[-1]
            state = None
            for i in range(0, len(prefill_ids), CHUNK_SIZE):
                prefill_token = prefill_ids[i : i + CHUNK_SIZE]
                _, state = self.model(prefill_token, state)

            # max_length is the total budget (prompt plus new tokens)
            gen_length = max_length - context.shape[1]
            for _ in range(gen_length):
                logits, state = self.model([next_token], state)
                next_token = int(torch.argmax(logits, dim=-1))
                all_outputs.append(next_token)

            # Return prompt + continuation, matching the HFLM convention that
            # the caller strips the prompt from the front of the output.
            continuation = torch.tensor(
                [all_outputs], dtype=context.dtype, device=context.device
            )
            return torch.cat([context, continuation], dim=1)
        else:
            stopping_criteria = lm_eval.models.utils.stop_sequences_criteria(
                self.tokenizer,
                stop,
                context.shape[1],
                context.shape[0],
            )
            generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0)
            do_sample = generation_kwargs.get("do_sample", None)

            # The temperature has to be a strictly positive float --
            # if it is 0.0, use greedy decoding strategies
            if generation_kwargs.get("temperature") == 0.0 and do_sample is None:
                generation_kwargs["do_sample"] = do_sample = False
            if do_sample is False and generation_kwargs.get("temperature") == 0.0:
                generation_kwargs.pop("temperature")

            return self.model.generate(
                input_ids=context,
                max_length=max_length,
                stopping_criteria=stopping_criteria,
                pad_token_id=self.tokenizer.pad_token_id,
                use_cache=True,
                **generation_kwargs,
            )
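
For reference, a minimal usage sketch of the newly registered model through the harness's Python API. The checkpoint path, tokenizer name, and task below are placeholders, not part of this commit; the explicit tokenizer argument is required because __init__ reads kwargs["tokenizer"] before delegating to HFLM.

# Hypothetical invocation -- model path, tokenizer, and task are placeholders.
import lm_eval

# "rwkv" is the name registered by @register_model("rwkv") above. Because the
# pretrained path does not end in "hf", the wrapper loads the checkpoint with
# the native rwkv package rather than the HF-compatible variant.
results = lm_eval.simple_evaluate(
    model="rwkv",
    model_args="pretrained=/path/to/rwkv-checkpoint.pth,tokenizer=RWKV/rwkv-4-169m-pile",
    tasks=["lambada_openai"],
)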