Commit 22f5ad26 authored by Baber

add rwkv

parent bb9ed305
from typing import Optional, Union

import torch

import lm_eval.models.utils
from lm_eval.api.registry import register_model
from lm_eval.models.huggingface import HFLM

@register_model("rwkv")
class MambaLMWrapper(HFLM):
def __init__(
self,
pretrained="",
# To use the HF compatible variant
is_hf: bool = False,
**kwargs,
) -> None:
if "backend" in kwargs:
assert kwargs["backend"] == "causal"
self.is_hf = is_hf or (True if pretrained.endswith("hf") else False)
self.tokenizer = kwargs["tokenizer"]
self.pretrained = pretrained
super().__init__(
pretrained=pretrained,
# set appropriate defaults for tokenizer, max length, etc
backend=kwargs.pop("backend", "causal"),
tokenizer=self.tokenizer,
max_length=kwargs.pop("max_length", 4096),
**kwargs,
)
    def _get_config(
        self,
        pretrained: str,
        **kwargs,
    ) -> None:
        if self.is_hf:
            super()._get_config(pretrained, **kwargs)
        else:
            # The standalone `rwkv` package has no HF-style config object.
            self._config = {}

    def _create_model(
        self,
        pretrained: str,
        dtype: Optional[Union[str, torch.dtype]] = "fp16",
        **kwargs,
    ) -> None:
        if self.is_hf:
            super()._create_model(pretrained, dtype=dtype, **kwargs)
        else:
            import os

            # These flags are read when `rwkv.model` is imported, so they must
            # be set before the import below.
            os.environ["RWKV_JIT_ON"] = "1"
            os.environ["RWKV_CUDA_ON"] = "1"
            os.environ["RWKV_V7_ON"] = "1"
            try:
                from rwkv.model import RWKV
            except ModuleNotFoundError as exception:
                raise type(exception)(
                    "attempted to use the 'rwkv' LM type, but package `rwkv` is not installed. "
                    "please install it via `pip install rwkv`",
                ) from exception
            self._model = RWKV(model=self.pretrained, strategy=f"cuda {dtype}")

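    # A minimal standalone sketch of what the non-HF branch above loads (an
    # illustrative assumption, not part of this module: the checkpoint path is a
    # placeholder, and the strategy string follows the `rwkv` package's
    # "<device> <dtype>" convention, e.g. "cuda fp16" or "cpu fp32"):
    #
    #   import os
    #   os.environ["RWKV_JIT_ON"] = "1"
    #   from rwkv.model import RWKV
    #
    #   model = RWKV(model="/path/to/rwkv-checkpoint.pth", strategy="cuda fp16")
    #   logits, state = model([1, 2, 3], None)  # logits for the last token, plus recurrent state
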
    def _model_generate(self, context, max_length, stop, **generation_kwargs):
        remove_arg = (
            ["attention_mask"] if self.is_hf else ["do_sample", "attention_mask"]
        )
        for key in remove_arg:
            if key in generation_kwargs:
                generation_kwargs.pop(key)

        if not self.is_hf:
            # The standalone `rwkv` package operates on plain lists of token ids
            # and keeps its own recurrent state, so decode greedily one token at
            # a time (batch size 1). Stop sequences are trimmed downstream.
            CHUNK_SIZE = 4096
            tokens = context.flatten().tolist()
            prefill_ids, next_token = tokens[:-1], tokens[-1]
            state = None
            # Prefill the prompt in chunks to bound memory usage.
            for i in range(0, len(prefill_ids), CHUNK_SIZE):
                prefill_chunk = prefill_ids[i : i + CHUNK_SIZE]
                _, state = self.model(prefill_chunk, state)
            # `max_length` is the total budget (prompt + continuation).
            gen_length = max_length - context.shape[1]
            all_outputs = []
            for _ in range(gen_length):
                logits, state = self.model([next_token], state)
                next_token = torch.argmax(logits, dim=-1).item()
                all_outputs.append(next_token)
            # Return prompt + continuation, matching the shape of HF `generate`
            # output so the caller can strip the prompt tokens.
            return torch.tensor(
                [tokens + all_outputs], dtype=torch.long, device=context.device
            )
        else:
            stopping_criteria = lm_eval.models.utils.stop_sequences_criteria(
                self.tokenizer,
                stop,
                context.shape[1],
                context.shape[0],
            )
            generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0)
            do_sample = generation_kwargs.get("do_sample", None)
            # The temperature has to be a strictly positive float -- if it is 0.0, use greedy decoding strategies
            if generation_kwargs.get("temperature") == 0.0 and do_sample is None:
                generation_kwargs["do_sample"] = do_sample = False
            if do_sample is False and generation_kwargs.get("temperature") == 0.0:
                generation_kwargs.pop("temperature")
            return self.model.generate(
                input_ids=context,
                max_length=max_length,
                stopping_criteria=stopping_criteria,
                pad_token_id=self.tokenizer.pad_token_id,
                use_cache=True,
                **generation_kwargs,
            )
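

# Example invocation (a sketch; the checkpoint path and tokenizer repo below are
# placeholders -- the wrapper requires an explicit `tokenizer` argument because a
# raw .pth checkpoint carries no tokenizer of its own):
#
#   lm_eval --model rwkv \
#       --model_args pretrained=/path/to/rwkv-checkpoint.pth,tokenizer=<hf-tokenizer-repo>,max_length=4096 \
#       --tasks lambada_openai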