"...gpu/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "e9143cd7e644239991b259a2f17a0a035e2b8c5a"
Commit 08e59d2c authored by Baber

fix batching

parent 7d286ad0
-from typing import Optional, Union
+from typing import List, Optional, Tuple, Union
 import torch
@@ -11,7 +11,7 @@ from lm_eval.models.huggingface import HFLM
 class RWKVWRAPPER(HFLM):
     def __init__(
         self,
-        pretrained,
+        pretrained="RWKV-x070-Pile-1.47B-20241210-ctx4096.pth",
         # To use the HF compatible variant
         is_hf: bool = False,
         **kwargs,
@@ -20,6 +20,7 @@ class RWKVWRAPPER(HFLM):
         assert kwargs["backend"] == "causal"
         self.is_hf = is_hf or (True if pretrained.endswith("hf") else False)
         assert kwargs["tokenizer"] is not None, "`tokenizer` is required"
+        assert kwargs["batch_size"] == 1, "`batch_size` must be 1"
         self.tokenizer = kwargs["tokenizer"]
         self.pretrained = pretrained
         super().__init__(
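
The constructor now insists on a causal backend, an explicit tokenizer, and single-element batches. A minimal construction sketch under those constraints; the tokenizer id below is an assumption, not something this commit pins down:

```python
# Hypothetical usage sketch: the arguments the new assertions require.
# The tokenizer id is an assumption (any HF tokenizer object should do,
# since tok_batch_encode calls it directly).
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
lm = RWKVWRAPPER(
    pretrained="RWKV-x070-Pile-1.47B-20241210-ctx4096.pth",
    backend="causal",     # assert kwargs["backend"] == "causal"
    tokenizer=tokenizer,  # assert kwargs["tokenizer"] is not None
    batch_size=1,         # assert kwargs["batch_size"] == 1 (this commit)
)
```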
@@ -63,7 +64,35 @@ class RWKVWRAPPER(HFLM):
         os.environ["RWKV_CUDA_ON"] = "1"
         os.environ["RWKV_V7_ON"] = "1"
-        self._model = RWKV(model=self.pretrained, strategy=f"cuda {dtype}")
+        import os
+        from huggingface_hub import hf_hub_download
+
+        def download_file(repo_id, filename, local_dir="./downloads"):
+            os.makedirs(local_dir, exist_ok=True)
+            path = hf_hub_download(
+                repo_id=repo_id,
+                filename=filename,
+                local_dir=local_dir,
+                local_dir_use_symlinks=False,
+            )
+            return path
+
+        for pretrained in [
+            "RWKV-x070-Pile-168M-20241120-ctx4096.pth",
+            "RWKV-x070-Pile-421M-20241127-ctx4096.pth",
+            "RWKV-x070-Pile-1.47B-20241210-ctx4096.pth",
+        ]:
+            download_file(
+                repo_id="BlinkDL/rwkv-7-pile",
+                filename=pretrained,
+                local_dir="rwkv_model",
+            )
+
+        self._model = RWKV(
+            model=f"rwkv_model/{pretrained}", strategy=f"cuda {dtype}"
+        )
     def _model_generate(self, context, max_length, stop, **generation_kwargs):
         remove_arg = (
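
`hf_hub_download` caches and skips files that already exist locally, so the loop above is idempotent across runs; `local_dir` just pins the checkpoints under `rwkv_model/` instead of the shared HF cache (and `local_dir_use_symlinks` is deprecated in recent `huggingface_hub`, so it can be omitted). The same download step in isolation:

```python
# Standalone sketch of the download step; the repo id and filename are
# the ones used in this commit.
from huggingface_hub import hf_hub_download

path = hf_hub_download(
    repo_id="BlinkDL/rwkv-7-pile",
    filename="RWKV-x070-Pile-168M-20241120-ctx4096.pth",
    local_dir="rwkv_model",
)
print(path)  # e.g. rwkv_model/RWKV-x070-Pile-168M-20241120-ctx4096.pth
```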
@@ -82,7 +111,8 @@ class RWKVWRAPPER(HFLM):
             prefill_token = prefill_ids[i : i + CHUNK_SIZE]
             _, state = self.model(prefill_token, state)
 
-        gen_length = context.shape[1] - max_length
+        # hack: self.gen_len is set in tok_batch_encode
+        gen_length = self.gen_len
         for i in range(gen_length):
             logits, state = self.model([next_token], state)
             next_token = torch.argmax(logits, dim=-1)
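
The generation path prefills the prompt in `CHUNK_SIZE` slices, then decodes greedily one token per step, carrying the recurrent state where a transformer would carry a KV cache. A self-contained sketch of that pattern against the `rwkv` package's stateful call convention (`logits, state = model(tokens, state)`); the helper and the `CHUNK_SIZE` value are illustrative, not from the commit:

```python
# Sketch (assumed helper): chunked prefill followed by greedy decoding
# with the rwkv package's recurrent state.
import torch

CHUNK_SIZE = 1024  # illustrative; the real value is defined elsewhere in this file

def greedy_generate(model, prefill_ids, gen_length):
    state = None
    # Prefill: feed the prompt in chunks, threading the state through.
    for i in range(0, len(prefill_ids), CHUNK_SIZE):
        logits, state = model(prefill_ids[i : i + CHUNK_SIZE], state)
    tokens = []
    next_token = int(torch.argmax(logits, dim=-1))
    # Decode: one token per step; the state replaces a KV cache.
    for _ in range(gen_length):
        tokens.append(next_token)
        logits, state = model([next_token], state)
        next_token = int(torch.argmax(logits, dim=-1))
    return tokens
```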
@@ -114,3 +144,18 @@ class RWKVWRAPPER(HFLM):
             use_cache=True,
             **generation_kwargs,
         )
+
+    def tok_batch_encode(
+        self,
+        strings: List[str],
+        padding_side: str = "left",
+        left_truncate_len: int = None,
+        truncation: bool = False,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        self.gen_len = self.max_length - left_truncate_len
+        encoding = self.tokenizer(
+            strings,
+            truncation=truncation,
+            return_tensors="pt",
+        )
+        return encoding["input_ids"], encoding["attention_mask"]
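
The overridden `tok_batch_encode` is where the batching fix lands: with `batch_size == 1` enforced there is nothing to pad, so it skips padding and stashes the token budget on `self.gen_len` for `_model_generate` to read back. Note it accepts `left_truncate_len` but only uses it to derive `gen_len`, never to truncate. In lm-eval's HFLM generate path, `left_truncate_len` is `max_length - max_gen_toks`, so `gen_len` works out to the requested `max_gen_toks`. A hypothetical direct call:

```python
# Hypothetical direct call; the harness normally invokes this internally.
ids, mask = lm.tok_batch_encode(
    ["The capital of France is"],
    left_truncate_len=lm.max_length - 64,  # leaves lm.gen_len == 64
)
```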