Commit 2b072879 authored by Baber's avatar Baber
Browse files

fixes

parent df94cfdd
...@@ -11,7 +11,7 @@ from lm_eval.models.huggingface import HFLM ...@@ -11,7 +11,7 @@ from lm_eval.models.huggingface import HFLM
class RWKVWRAPPER(HFLM): class RWKVWRAPPER(HFLM):
def __init__( def __init__(
self, self,
pretrained="RWKV-x070-Pile-1.47B-20241210-ctx4096.pth", pretrained="RWKV-x070-Pile-1.47B-20241210-ctx4096",
# To use the HF compatible variant # To use the HF compatible variant
is_hf: bool = False, is_hf: bool = False,
**kwargs, **kwargs,
...@@ -80,13 +80,13 @@ class RWKVWRAPPER(HFLM): ...@@ -80,13 +80,13 @@ class RWKVWRAPPER(HFLM):
return path return path
for pretrained in [ for pretrained in [
"RWKV-x070-Pile-168M-20241120-ctx4096.pth", "RWKV-x070-Pile-168M-20241120-ctx4096",
"RWKV-x070-Pile-421M-20241127-ctx4096.pth", "RWKV-x070-Pile-421M-20241127-ctx4096",
"RWKV-x070-Pile-1.47B-20241210-ctx4096.pth", "RWKV-x070-Pile-1.47B-20241210-ctx4096",
]: ]:
download_file( download_file(
repo_id="BlinkDL/rwkv-7-pile", repo_id="BlinkDL/rwkv-7-pile",
filename=pretrained, filename=pretrained + ".pth",
local_dir="rwkv_model", local_dir="rwkv_model",
) )
...@@ -104,7 +104,7 @@ class RWKVWRAPPER(HFLM): ...@@ -104,7 +104,7 @@ class RWKVWRAPPER(HFLM):
all_outputs = [] all_outputs = []
if not self.is_hf: if not self.is_hf:
CHUNK_SIZE = 4096 CHUNK_SIZE = 4096
context = context.squeeze() context = context.squeeze().tolist()
prefill_ids, next_token = context[:-1], context[-1] prefill_ids, next_token = context[:-1], context[-1]
state = None state = None
for i in range(0, len(prefill_ids), CHUNK_SIZE): for i in range(0, len(prefill_ids), CHUNK_SIZE):
...@@ -118,7 +118,7 @@ class RWKVWRAPPER(HFLM): ...@@ -118,7 +118,7 @@ class RWKVWRAPPER(HFLM):
next_token = torch.argmax(logits, dim=-1) next_token = torch.argmax(logits, dim=-1)
all_outputs.append(next_token) all_outputs.append(next_token)
return torch.cat(all_outputs) return torch.stack(all_outputs).unsqueeze(0)
else: else:
stopping_criteria = lm_eval.models.utils.stop_sequences_criteria( stopping_criteria = lm_eval.models.utils.stop_sequences_criteria(
self.tokenizer, self.tokenizer,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment