Commit 2b072879 authored by Baber

fixes

parent df94cfdd
@@ -11,7 +11,7 @@ from lm_eval.models.huggingface import HFLM
 class RWKVWRAPPER(HFLM):
     def __init__(
         self,
-        pretrained="RWKV-x070-Pile-1.47B-20241210-ctx4096.pth",
+        pretrained="RWKV-x070-Pile-1.47B-20241210-ctx4096",
         # To use the HF compatible variant
         is_hf: bool = False,
         **kwargs,
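
With the extension dropped, the default `pretrained` value is a bare checkpoint name rather than a file name. A minimal instantiation sketch, assuming the wrapper is importable from this module (the import path, and that no further constructor arguments are required, are assumptions):

    # Hypothetical usage; the module path is an assumption.
    from rwkv_wrapper import RWKVWRAPPER

    # Native RWKV checkpoint, referenced without the ".pth" suffix;
    # pass is_hf=True to use the HF-compatible variant instead.
    lm = RWKVWRAPPER(pretrained="RWKV-x070-Pile-1.47B-20241210-ctx4096", is_hf=False)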
@@ -80,13 +80,13 @@ class RWKVWRAPPER(HFLM):
         return path
     for pretrained in [
-        "RWKV-x070-Pile-168M-20241120-ctx4096.pth",
-        "RWKV-x070-Pile-421M-20241127-ctx4096.pth",
-        "RWKV-x070-Pile-1.47B-20241210-ctx4096.pth",
+        "RWKV-x070-Pile-168M-20241120-ctx4096",
+        "RWKV-x070-Pile-421M-20241127-ctx4096",
+        "RWKV-x070-Pile-1.47B-20241210-ctx4096",
     ]:
         download_file(
             repo_id="BlinkDL/rwkv-7-pile",
-            filename=pretrained,
+            filename=pretrained + ".pth",
             local_dir="rwkv_model",
         )
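
The files in the BlinkDL/rwkv-7-pile repo do carry the `.pth` extension, so it is appended only at download time and `pretrained` stays extension-free everywhere else. A minimal sketch of what `download_file` plausibly wraps, assuming it delegates to the standard Hugging Face Hub API (the delegation itself is an assumption):

    from huggingface_hub import hf_hub_download

    def download_file(repo_id: str, filename: str, local_dir: str) -> str:
        # Fetches e.g. "RWKV-x070-Pile-421M-20241127-ctx4096.pth" from the Hub
        # into local_dir and returns the local file path. Assumes download_file
        # is a thin wrapper over hf_hub_download; the real helper may differ.
        return hf_hub_download(repo_id=repo_id, filename=filename, local_dir=local_dir)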
@@ -104,7 +104,7 @@ class RWKVWRAPPER(HFLM):
         all_outputs = []
         if not self.is_hf:
             CHUNK_SIZE = 4096
-            context = context.squeeze()
+            context = context.squeeze().tolist()
             prefill_ids, next_token = context[:-1], context[-1]
             state = None
             for i in range(0, len(prefill_ids), CHUNK_SIZE):
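
The `.tolist()` matters here because slicing a 1-D tensor would leave `prefill_ids` as a tensor and `next_token` as a 0-d tensor, whereas the native RWKV forward pass consumes plain Python ints. A toy illustration of the resulting chunked prefill, with the model call stubbed out (the `model.forward(tokens, state)` signature follows the public RWKV reference code and is an assumption about this wrapper):

    import torch

    context = torch.tensor([[510, 3158, 8516, 30013]])  # (1, seq_len) input from lm-eval

    ids = context.squeeze().tolist()                # [510, 3158, 8516, 30013], plain ints
    prefill_ids, next_token = ids[:-1], ids[-1]     # list of ints, single int

    CHUNK_SIZE = 4096
    state = None
    for i in range(0, len(prefill_ids), CHUNK_SIZE):
        chunk = prefill_ids[i : i + CHUNK_SIZE]     # a list slice, as RWKV expects
        # logits, state = model.forward(chunk, state)  # hypothetical native RWKV call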
@@ -118,7 +118,7 @@ class RWKVWRAPPER(HFLM):
                 next_token = torch.argmax(logits, dim=-1)
                 all_outputs.append(next_token)
-            return torch.cat(all_outputs)
+            return torch.stack(all_outputs).unsqueeze(0)
         else:
             stopping_criteria = lm_eval.models.utils.stop_sequences_criteria(
                 self.tokenizer,
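`torch.cat` fails on the 0-d tensors that `torch.argmax` over a 1-D logits vector produces ("zero-dimensional tensor cannot be concatenated"), while `torch.stack` assembles them into a 1-D sequence; `.unsqueeze(0)` then restores the `(1, seq_len)` batch shape the HF `generate` branch returns. A self-contained check, assuming each step's logits are a 1-D vocab-sized tensor:

    import torch

    # Each decode step yields a 0-d tensor from argmax over the vocab dimension.
    steps = [torch.argmax(torch.randn(50277), dim=-1) for _ in range(3)]

    # torch.cat(steps) would raise a RuntimeError on 0-d tensors.
    out = torch.stack(steps).unsqueeze(0)
    print(out.shape)  # torch.Size([1, 3]) -- matches the batched HF output shape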