gaoqiong / lm-evaluation-harness

Commit 08e59d2c, authored Jan 23, 2025 by Baber
Parent: 7d286ad0

    fix batching

Showing 1 changed file, lm_eval/models/rwkvwrapper.py, with 49 additions and 4 deletions (+49, -4).

lm_eval/models/rwkvwrapper.py:
-from typing import Optional, Union
+from typing import List, Optional, Tuple, Union
 import torch
@@ -11,7 +11,7 @@ from lm_eval.models.huggingface import HFLM
 class RWKVWRAPPER(HFLM):
     def __init__(
         self,
-        pretrained,
+        pretrained="RWKV-x070-Pile-1.47B-20241210-ctx4096.pth",
         # To use the HF compatible variant
         is_hf: bool = False,
         **kwargs,
@@ -20,6 +20,7 @@ class RWKVWRAPPER(HFLM):
         assert kwargs["backend"] == "causal"
         self.is_hf = is_hf or (True if pretrained.endswith("hf") else False)
         assert kwargs["tokenizer"] is not None, "`tokenizer` is required"
+        assert kwargs["batch_size"] == 1, "`batch_size` must be 1"
         self.tokenizer = kwargs["tokenizer"]
         self.pretrained = pretrained
         super().__init__(
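With this hunk the constructor enforces the setup the rest of the commit relies on: a causal backend, an explicit tokenizer, and a batch size of exactly 1 (the decode loop below handles a single sequence at a time), with the `hf`-suffix check on `pretrained` offering a second way to select the HF-compatible variant. A minimal instantiation sketch under those assertions follows; only the keyword names come from the diff, and the tokenizer checkpoint is an assumption (the Pile-trained RWKV-7 models use a GPT-NeoX-style vocabulary).

from transformers import AutoTokenizer

# Hypothetical usage sketch; EleutherAI/pythia-160m is an assumed stand-in
# for a GPT-NeoX-style tokenizer, not something named in the commit.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-160m")

lm = RWKVWRAPPER(
    pretrained="RWKV-x070-Pile-1.47B-20241210-ctx4096.pth",
    backend="causal",     # asserted: kwargs["backend"] == "causal"
    tokenizer=tokenizer,  # asserted: must not be None
    batch_size=1,         # asserted: must be exactly 1
)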
@@ -63,7 +64,35 @@ class RWKVWRAPPER(HFLM):
         os.environ["RWKV_CUDA_ON"] = "1"
         os.environ["RWKV_V7_ON"] = "1"
-        self._model = RWKV(model=self.pretrained, strategy=f"cuda {dtype}")
+        import os
+
+        from huggingface_hub import hf_hub_download
+
+        def download_file(repo_id, filename, local_dir="./downloads"):
+            os.makedirs(local_dir, exist_ok=True)
+            path = hf_hub_download(
+                repo_id=repo_id,
+                filename=filename,
+                local_dir=local_dir,
+                local_dir_use_symlinks=False,
+            )
+            return path
+
+        for pretrained in [
+            "RWKV-x070-Pile-168M-20241120-ctx4096.pth",
+            "RWKV-x070-Pile-421M-20241127-ctx4096.pth",
+            "RWKV-x070-Pile-1.47B-20241210-ctx4096.pth",
+        ]:
+            download_file(
+                repo_id="BlinkDL/rwkv-7-pile",
+                filename=pretrained,
+                local_dir="rwkv_model",
+            )
+
+        self._model = RWKV(
+            model=f"rwkv_model/{pretrained}",
+            strategy=f"cuda {dtype}",
+        )

     def _model_generate(self, context, max_length, stop, **generation_kwargs):
         remove_arg = (
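The replacement body pre-downloads all three RWKV-7 Pile checkpoints from the BlinkDL/rwkv-7-pile repo; note that the loop rebinds the name `pretrained`, so the `RWKV(...)` call always loads the last checkpoint in the list (the 1.47B model) rather than `self.pretrained`. For reference, a standalone sketch of the same download step for a single file, assuming `huggingface_hub` is installed (`local_dir_use_symlinks` is still accepted but deprecated in recent releases, so it is omitted here):

from huggingface_hub import hf_hub_download

# Fetch one checkpoint into a local directory; hf_hub_download returns the
# path to the file and skips the transfer if it is already present.
path = hf_hub_download(
    repo_id="BlinkDL/rwkv-7-pile",
    filename="RWKV-x070-Pile-1.47B-20241210-ctx4096.pth",
    local_dir="rwkv_model",
)
print(path)  # e.g. rwkv_model/RWKV-x070-Pile-1.47B-20241210-ctx4096.pth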
@@ -82,7 +111,8 @@ class RWKVWRAPPER(HFLM):
prefill_token
=
prefill_ids
[
i
:
i
+
CHUNK_SIZE
]
_
,
state
=
self
.
model
(
prefill_token
,
state
)
gen_length
=
context
.
shape
[
1
]
-
max_length
# hack: self.gen_len is set in tok_batch_encode
gen_length
=
self
.
gen_len
for
i
in
range
(
gen_length
):
logits
,
state
=
self
.
model
([
next_token
],
state
)
next_token
=
torch
.
argmax
(
logits
,
dim
=-
1
)
...
...
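This hunk swaps the shape-based length arithmetic for a value stashed by `tok_batch_encode` (added in the final hunk), keeping the chunked prefill and the one-token-at-a-time greedy loop. A self-contained sketch of the same stateful decode pattern with the `rwkv` package, whose `forward(tokens, state)` returns `(logits, state)`; the checkpoint path, chunk size, prompt ids, and generation budget are assumptions:

import os

os.environ["RWKV_V7_ON"] = "1"  # mirror the wrapper: enable RWKV-7 support

import torch
from rwkv.model import RWKV

# Assumed: the checkpoint was downloaded as in the previous hunk, and a GPU
# is available for the "cuda fp16" strategy.
model = RWKV(
    model="rwkv_model/RWKV-x070-Pile-1.47B-20241210-ctx4096.pth",
    strategy="cuda fp16",
)

prompt_ids = [510, 3158, 8516]  # placeholder token ids
CHUNK_SIZE = 128

# Prefill the recurrent state in chunks, as _model_generate does.
state = None
for i in range(0, len(prompt_ids), CHUNK_SIZE):
    logits, state = model.forward(prompt_ids[i : i + CHUNK_SIZE], state)

# Greedy decode: feed back the argmax token and carry the state forward.
generated = []
next_token = int(torch.argmax(logits, dim=-1))
for _ in range(16):  # plays the role of self.gen_len
    generated.append(next_token)
    logits, state = model.forward([next_token], state)
    next_token = int(torch.argmax(logits, dim=-1))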
@@ -114,3 +144,18 @@ class RWKVWRAPPER(HFLM):
             use_cache=True,
             **generation_kwargs,
         )
+
+    def tok_batch_encode(
+        self,
+        strings: List[str],
+        padding_side: str = "left",
+        left_truncate_len: int = None,
+        truncation: bool = False,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        self.gen_len = self.max_length - left_truncate_len
+        encoding = self.tokenizer(
+            strings,
+            truncation=truncation,
+            return_tensors="pt",
+        )
+        return encoding["input_ids"], encoding["attention_mask"]
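`tok_batch_encode` is where the batching fix lands: it stashes `self.gen_len = self.max_length - left_truncate_len` for `_model_generate` to read (the "hack" comment in the earlier hunk), and with `batch_size` pinned to 1 it can skip padding entirely, which is why `padding_side` is accepted but unused. Assuming the harness calls it the way lm-eval's `generate_until` does for causal models, with `left_truncate_len = max_length - max_gen_toks`, the stashed value works out to exactly the number of tokens to generate:

# Worked example of the gen_len bookkeeping (values are illustrative).
max_length = 4096     # model context window
max_gen_toks = 256    # tokens the task wants generated
left_truncate_len = max_length - max_gen_toks  # 3840, as passed by the harness

gen_len = max_length - left_truncate_len
assert gen_len == max_gen_toks  # _model_generate decodes exactly this many steps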