"...gpu/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "824d9c5f3adacd6042946c41be7076f465de0fec"
Commit 08e59d2c authored by Baber's avatar Baber
Browse files

fix batching

parent 7d286ad0
-from typing import Optional, Union
+from typing import List, Optional, Tuple, Union
 import torch
@@ -11,7 +11,7 @@ from lm_eval.models.huggingface import HFLM
 class RWKVWRAPPER(HFLM):
     def __init__(
         self,
-        pretrained,
+        pretrained="RWKV-x070-Pile-1.47B-20241210-ctx4096.pth",
         # To use the HF compatible variant
         is_hf: bool = False,
         **kwargs,
@@ -20,6 +20,7 @@ class RWKVWRAPPER(HFLM):
assert kwargs["backend"] == "causal"
self.is_hf = is_hf or (True if pretrained.endswith("hf") else False)
assert kwargs["tokenizer"] is not None, "`tokenizer` is required"
assert kwargs["batch_size"] == 1, "`batch_size` must be 1"
self.tokenizer = kwargs["tokenizer"]
self.pretrained = pretrained
super().__init__(
@@ -63,7 +64,35 @@ class RWKVWRAPPER(HFLM):
os.environ["RWKV_CUDA_ON"] = "1"
os.environ["RWKV_V7_ON"] = "1"
self._model = RWKV(model=self.pretrained, strategy=f"cuda {dtype}")
import os
from huggingface_hub import hf_hub_download
def download_file(repo_id, filename, local_dir="./downloads"):
os.makedirs(local_dir, exist_ok=True)
path = hf_hub_download(
repo_id=repo_id,
filename=filename,
local_dir=local_dir,
local_dir_use_symlinks=False,
)
return path
for pretrained in [
"RWKV-x070-Pile-168M-20241120-ctx4096.pth",
"RWKV-x070-Pile-421M-20241127-ctx4096.pth",
"RWKV-x070-Pile-1.47B-20241210-ctx4096.pth",
]:
download_file(
repo_id="BlinkDL/rwkv-7-pile",
filename=pretrained,
local_dir="rwkv_model",
)
self._model = RWKV(
model=f"rwkv_model/{pretrained}", strategy=f"cuda {dtype}"
)
def _model_generate(self, context, max_length, stop, **generation_kwargs):
remove_arg = (
@@ -82,7 +111,8 @@ class RWKVWRAPPER(HFLM):
             prefill_token = prefill_ids[i : i + CHUNK_SIZE]
             _, state = self.model(prefill_token, state)
-        gen_length = context.shape[1] - max_length
+        # hack: self.gen_len is set in tok_batch_encode
+        gen_length = self.gen_len
         for i in range(gen_length):
            logits, state = self.model([next_token], state)
            next_token = torch.argmax(logits, dim=-1)
@@ -114,3 +144,18 @@ class RWKVWRAPPER(HFLM):
             use_cache=True,
             **generation_kwargs,
         )
+
+    def tok_batch_encode(
+        self,
+        strings: List[str],
+        padding_side: str = "left",
+        left_truncate_len: int = None,
+        truncation: bool = False,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        self.gen_len = self.max_length - left_truncate_len
+        encoding = self.tokenizer(
+            strings,
+            truncation=truncation,
+            return_tensors="pt",
+        )
+        return encoding["input_ids"], encoding["attention_mask"]