Commit 08e59d2c authored by Baber

fix batching

parent 7d286ad0
-from typing import Optional, Union
+from typing import List, Optional, Tuple, Union
 import torch
@@ -11,7 +11,7 @@ from lm_eval.models.huggingface import HFLM
 class RWKVWRAPPER(HFLM):
     def __init__(
         self,
-        pretrained,
+        pretrained="RWKV-x070-Pile-1.47B-20241210-ctx4096.pth",
         # To use the HF compatible variant
         is_hf: bool = False,
         **kwargs,
@@ -20,6 +20,7 @@ class RWKVWRAPPER(HFLM):
         assert kwargs["backend"] == "causal"
         self.is_hf = is_hf or (True if pretrained.endswith("hf") else False)
         assert kwargs["tokenizer"] is not None, "`tokenizer` is required"
+        assert kwargs["batch_size"] == 1, "`batch_size` must be 1"
         self.tokenizer = kwargs["tokenizer"]
         self.pretrained = pretrained
         super().__init__(
@@ -63,7 +64,35 @@ class RWKVWRAPPER(HFLM):
             os.environ["RWKV_CUDA_ON"] = "1"
             os.environ["RWKV_V7_ON"] = "1"
-            self._model = RWKV(model=self.pretrained, strategy=f"cuda {dtype}")
+            import os
+            from huggingface_hub import hf_hub_download
+
+            def download_file(repo_id, filename, local_dir="./downloads"):
+                os.makedirs(local_dir, exist_ok=True)
+                path = hf_hub_download(
+                    repo_id=repo_id,
+                    filename=filename,
+                    local_dir=local_dir,
+                    local_dir_use_symlinks=False,
+                )
+                return path
+
+            for pretrained in [
+                "RWKV-x070-Pile-168M-20241120-ctx4096.pth",
+                "RWKV-x070-Pile-421M-20241127-ctx4096.pth",
+                "RWKV-x070-Pile-1.47B-20241210-ctx4096.pth",
+            ]:
+                download_file(
+                    repo_id="BlinkDL/rwkv-7-pile",
+                    filename=pretrained,
+                    local_dir="rwkv_model",
+                )
+
+            self._model = RWKV(
+                model=f"rwkv_model/{pretrained}", strategy=f"cuda {dtype}"
+            )

     def _model_generate(self, context, max_length, stop, **generation_kwargs):
         remove_arg = (
@@ -82,7 +111,8 @@ class RWKVWRAPPER(HFLM):
             prefill_token = prefill_ids[i : i + CHUNK_SIZE]
             _, state = self.model(prefill_token, state)
-        gen_length = context.shape[1] - max_length
+        # hack: self.gen_len is set in tok_batch_encode
+        gen_length = self.gen_len
         for i in range(gen_length):
            logits, state = self.model([next_token], state)
            next_token = torch.argmax(logits, dim=-1)
@@ -114,3 +144,18 @@ class RWKVWRAPPER(HFLM):
             use_cache=True,
             **generation_kwargs,
         )
+
+    def tok_batch_encode(
+        self,
+        strings: List[str],
+        padding_side: str = "left",
+        left_truncate_len: int = None,
+        truncation: bool = False,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        self.gen_len = self.max_length - left_truncate_len
+        encoding = self.tokenizer(
+            strings,
+            truncation=truncation,
+            return_tensors="pt",
+        )
+        return encoding["input_ids"], encoding["attention_mask"]
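
For reference, a minimal sketch of how the wrapper is constructed after this commit. The tokenizer repo id is an assumption (GPT-NeoX is a guess for the Pile checkpoints), and the keyword values are inferred from the asserts in __init__; the commit itself does not show an invocation:

# Hypothetical usage sketch; RWKVWRAPPER is assumed importable from this module,
# and the tokenizer repo id is an assumption, not pinned down by the commit.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")

lm = RWKVWRAPPER(
    pretrained="RWKV-x070-Pile-1.47B-20241210-ctx4096.pth",
    tokenizer=tokenizer,  # asserted non-None in __init__
    backend="causal",     # asserted equal to "causal"
    batch_size=1,         # the new assert rejects any other value
)

The batch_size == 1 restriction matches the single-sequence decode loop in _model_generate, which steps one token at a time through a single RWKV state. self.gen_len, recorded in tok_batch_encode as self.max_length - left_truncate_len, replaces the previous context.shape[1]-based calculation of the generation budget; in lm_eval's generate_until path, left_truncate_len is typically self.max_length minus max_gen_toks, so gen_len effectively recovers max_gen_toks without relying on padded batch shapes. Note also that the checkpoint loader reuses the loop variable pretrained, so the model loaded afterward is the last entry in the download list.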