Commit 08e59d2c authored by Baber

fix batching

parent 7d286ad0
-from typing import Optional, Union
+from typing import List, Optional, Tuple, Union
 import torch
@@ -11,7 +11,7 @@ from lm_eval.models.huggingface import HFLM
 class RWKVWRAPPER(HFLM):
     def __init__(
         self,
-        pretrained,
+        pretrained="RWKV-x070-Pile-1.47B-20241210-ctx4096.pth",
         # To use the HF compatible variant
         is_hf: bool = False,
         **kwargs,
@@ -20,6 +20,7 @@ class RWKVWRAPPER(HFLM):
         assert kwargs["backend"] == "causal"
         self.is_hf = is_hf or (True if pretrained.endswith("hf") else False)
         assert kwargs["tokenizer"] is not None, "`tokenizer` is required"
+        assert kwargs["batch_size"] == 1, "`batch_size` must be 1"
         self.tokenizer = kwargs["tokenizer"]
         self.pretrained = pretrained
         super().__init__(
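A note on the new assert: it pins evaluation to single-sample batches, matching the one-token-at-a-time decode loop in `_model_generate` below, and it indexes `kwargs["batch_size"]` directly, so a missing key raises `KeyError` rather than the assert message. A minimal instantiation sketch under these constraints (values are illustrative, not from this commit):

    lm = RWKVWRAPPER(
        pretrained="RWKV-x070-Pile-1.47B-20241210-ctx4096.pth",
        tokenizer=tokenizer,  # asserted non-None in __init__
        backend="causal",     # asserted in __init__
        batch_size=1,         # any other value now fails the new assert
    )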
@@ -63,7 +64,35 @@ class RWKVWRAPPER(HFLM):
         os.environ["RWKV_CUDA_ON"] = "1"
         os.environ["RWKV_V7_ON"] = "1"
-        self._model = RWKV(model=self.pretrained, strategy=f"cuda {dtype}")
+        import os
+        from huggingface_hub import hf_hub_download
+
+        def download_file(repo_id, filename, local_dir="./downloads"):
+            os.makedirs(local_dir, exist_ok=True)
+            path = hf_hub_download(
+                repo_id=repo_id,
+                filename=filename,
+                local_dir=local_dir,
+                local_dir_use_symlinks=False,
+            )
+            return path
+
+        for pretrained in [
+            "RWKV-x070-Pile-168M-20241120-ctx4096.pth",
+            "RWKV-x070-Pile-421M-20241127-ctx4096.pth",
+            "RWKV-x070-Pile-1.47B-20241210-ctx4096.pth",
+        ]:
+            download_file(
+                repo_id="BlinkDL/rwkv-7-pile",
+                filename=pretrained,
+                local_dir="rwkv_model",
+            )
+
+        self._model = RWKV(
+            model=f"rwkv_model/{pretrained}", strategy=f"cuda {dtype}"
+        )

     def _model_generate(self, context, max_length, stop, **generation_kwargs):
         remove_arg = (
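A note on the download hunk above: the loop fetches all three Pile checkpoints and leaves the loop variable `pretrained` bound to its last value, shadowing the constructor argument, so the 1.47B model is always the one loaded; the inner `import os` is also redundant given the `os.environ` calls two lines earlier. A sketch that downloads only the requested checkpoint (assuming `self.pretrained` names a file in `BlinkDL/rwkv-7-pile`):

    path = hf_hub_download(
        repo_id="BlinkDL/rwkv-7-pile",
        filename=self.pretrained,
        local_dir="rwkv_model",
    )
    self._model = RWKV(model=path, strategy=f"cuda {dtype}")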
@@ -82,7 +111,8 @@ class RWKVWRAPPER(HFLM):
             prefill_token = prefill_ids[i : i + CHUNK_SIZE]
             _, state = self.model(prefill_token, state)
-        gen_length = context.shape[1] - max_length
+        # hack: self.gen_len is set in tok_batch_encode
+        gen_length = self.gen_len
         for i in range(gen_length):
             logits, state = self.model([next_token], state)
             next_token = torch.argmax(logits, dim=-1)
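The stashed `self.gen_len` replaces `gen_length = context.shape[1] - max_length`, which is negative whenever `max_length` exceeds the prompt length (in HFLM's convention `max_length` is the prompt length plus the generation budget), leaving `range(gen_length)` empty so no tokens were ever generated. Assuming `generate_until` calls `tok_batch_encode` with `left_truncate_len = self.max_length - max_gen_toks`, the stash recovers the budget:

    # illustrative numbers, not from this commit:
    # max_length = 4096, max_gen_toks = 256
    # left_truncate_len = 4096 - 256  = 3840
    # self.gen_len      = 4096 - 3840 = 256   # tokens to generate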
@@ -114,3 +144,18 @@ class RWKVWRAPPER(HFLM):
             use_cache=True,
             **generation_kwargs,
         )
+
+    def tok_batch_encode(
+        self,
+        strings: List[str],
+        padding_side: str = "left",
+        left_truncate_len: int = None,
+        truncation: bool = False,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        self.gen_len = self.max_length - left_truncate_len
+        encoding = self.tokenizer(
+            strings,
+            truncation=truncation,
+            return_tensors="pt",
+        )
+        return encoding["input_ids"], encoding["attention_mask"]
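Two caveats on the new method: `padding_side` is accepted but unused (harmless while `batch_size` is pinned to 1), and `left_truncate_len` defaults to `None`, in which case `self.max_length - left_truncate_len` raises `TypeError`; note also that no left truncation is actually applied to the encodings. A defensive sketch under the same calling convention:

    def tok_batch_encode(
        self,
        strings: List[str],
        padding_side: str = "left",
        left_truncate_len: Optional[int] = None,
        truncation: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # self.max_length - None would raise TypeError, so require the argument
        assert left_truncate_len is not None, "`left_truncate_len` is required"
        self.gen_len = self.max_length - left_truncate_len
        encoding = self.tokenizer(strings, truncation=truncation, return_tensors="pt")
        return encoding["input_ids"], encoding["attention_mask"]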