gaoqiong / lm-evaluation-harness · Commits

Commit 0cce16be, authored Jan 23, 2025 by Baber
parent 2b072879

fix gen
Showing 1 changed file with 13 additions and 2 deletions
lm_eval/models/rwkvwrapper.py  +13  -2  (view file @ 0cce16be)

 from typing import List, Optional, Tuple, Union

 import torch
+import torch.nn.functional as F

 import lm_eval.models.utils
 from lm_eval.api.registry import register_model
...
@@ -93,7 +94,14 @@ class RWKVWRAPPER(HFLM):
         self._model = RWKV(model=f"rwkv_model/{pretrained}", strategy="cuda fp16")
         self._model.tie_weights = lambda: None

-    def _model_generate(self, context, max_length, stop, **generation_kwargs):
+    def _model_generate(
+        self,
+        context: "torch.tensor",
+        max_length: int,
+        stop: list[str],
+        **generation_kwargs,
+    ) -> "torch.tensor":
+        context_len = context.shape[1]
         remove_arg = (
             ["attention_mask"] if self.is_hf else ["do_sample", "attention_mask"]
         )
...
@@ -118,7 +126,10 @@ class RWKVWRAPPER(HFLM):
                 next_token = torch.argmax(logits, dim=-1)
                 all_outputs.append(next_token)
-            return torch.stack(all_outputs).unsqueeze(0)
+            # return context + gen (context gets trimmed downstream)
+            return F.pad(
+                torch.stack(all_outputs).to("cpu"), (context_len, 0)
+            ).unsqueeze(0)
         else:
             stopping_criteria = lm_eval.models.utils.stop_sequences_criteria(
                 self.tokenizer,
...
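The new return path explains the commit's own comment, "return context + gen (context gets trimmed downstream)": `F.pad(..., (context_len, 0))` left-pads the stacked generated token ids with `context_len` zeros as placeholders for the prompt, so the output has the same "context followed by generation" layout the harness strips downstream. A minimal sketch of that padding behavior, with made-up token ids purely for illustration:

import torch
import torch.nn.functional as F

# Hypothetical example: a 5-token prompt and 3 generated token ids.
context_len = 5
generated = torch.tensor([101, 102, 103])  # stand-in token ids

# For a 1-D tensor, F.pad takes (left, right); (context_len, 0) prepends
# context_len zeros, giving context_len + gen_len entries before the
# batch dimension is added with unsqueeze(0).
padded = F.pad(generated, (context_len, 0)).unsqueeze(0)
print(padded)        # tensor([[  0,   0,   0,   0,   0, 101, 102, 103]])
print(padded.shape)  # torch.Size([1, 8])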