Commit 0cce16be authored by Baber's avatar Baber
Browse files

fix gen

parent 2b072879
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import lm_eval.models.utils
from lm_eval.api.registry import register_model
......@@ -93,7 +94,14 @@ class RWKVWRAPPER(HFLM):
self._model = RWKV(model=f"rwkv_model/{pretrained}", strategy="cuda fp16")
self._model.tie_weights = lambda: None
def _model_generate(self, context, max_length, stop, **generation_kwargs):
def _model_generate(
self,
context: "torch.tensor",
max_length: int,
stop: list[str],
**generation_kwargs,
) -> "torch.tensor":
context_len = context.shape[1]
remove_arg = (
["attention_mask"] if self.is_hf else ["do_sample", "attention_mask"]
)
......@@ -118,7 +126,10 @@ class RWKVWRAPPER(HFLM):
next_token = torch.argmax(logits, dim=-1)
all_outputs.append(next_token)
return torch.stack(all_outputs).unsqueeze(0)
# return context + gen (context gets trimmed downstream)
return F.pad(
torch.stack(all_outputs).to("cpu"), (context_len, 0)
).unsqueeze(0)
else:
stopping_criteria = lm_eval.models.utils.stop_sequences_criteria(
self.tokenizer,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment