"docs/source/vscode:/vscode.git/clone" did not exist on "f98b745a81df1613a9c5f1d5986456663f86c457"
Commit 0cce16be authored by Baber

fix gen

parent 2b072879
 from typing import List, Optional, Tuple, Union
 import torch
+import torch.nn.functional as F
 import lm_eval.models.utils
 from lm_eval.api.registry import register_model

@@ -93,7 +94,14 @@ class RWKVWRAPPER(HFLM):
         self._model = RWKV(model=f"rwkv_model/{pretrained}", strategy="cuda fp16")
         self._model.tie_weights = lambda: None

-    def _model_generate(self, context, max_length, stop, **generation_kwargs):
+    def _model_generate(
+        self,
+        context: "torch.Tensor",
+        max_length: int,
+        stop: List[str],
+        **generation_kwargs,
+    ) -> "torch.Tensor":
+        context_len = context.shape[1]
         remove_arg = (
             ["attention_mask"] if self.is_hf else ["do_sample", "attention_mask"]
         )

@@ -118,7 +126,10 @@ class RWKVWRAPPER(HFLM):
                 next_token = torch.argmax(logits, dim=-1)
                 all_outputs.append(next_token)
-            return torch.stack(all_outputs).unsqueeze(0)
+            # return context + gen (context gets trimmed downstream)
+            return F.pad(
+                torch.stack(all_outputs).to("cpu"), (context_len, 0)
+            ).unsqueeze(0)
         else:
             stopping_criteria = lm_eval.models.utils.stop_sequences_criteria(
                 self.tokenizer,
...
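
The second hunk is the substance of the fix: the diff's own comment says the returned sequence should be "context + gen" because the context portion gets trimmed downstream, so returning only the generated tokens would make the caller cut off real output. Below is a minimal sketch of what the new return expression does, with toy stand-in values (`context_len` and the token ids are made up for illustration; they are not from the commit):

```python
import torch
import torch.nn.functional as F

# Hypothetical stand-ins for the values in the diff.
context_len = 4  # prompt length, i.e. context.shape[1]
all_outputs = [torch.tensor(t) for t in (11, 22, 33)]  # greedy-decoded token ids

# Left-pad the stacked generated tokens with `context_len` zeros so the result
# has shape (1, context_len + num_generated), matching the "context + gen"
# layout the downstream trimming expects.
gen = torch.stack(all_outputs).to("cpu")            # shape: (3,)
padded = F.pad(gen, (context_len, 0)).unsqueeze(0)  # shape: (1, 7)

print(padded)  # tensor([[ 0,  0,  0,  0, 11, 22, 33]])
```

The zeros are placeholders rather than the real context tokens, which is safe here only because the first `context_len` positions are discarded downstream anyway.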