Commit df94cfdd authored by Baber
Browse files

nit

parent 4ee812a0
......@@ -90,9 +90,8 @@ class RWKVWRAPPER(HFLM):
local_dir="rwkv_model",
)
self._model = RWKV(
model=f"rwkv_model/{pretrained}", strategy=f"cuda {dtype}"
)
self._model = RWKV(model=f"rwkv_model/{pretrained}", strategy="cuda fp16")
self._model.tie_weights = lambda: None
def _model_generate(self, context, max_length, stop, **generation_kwargs):
remove_arg = (
......@@ -105,6 +104,7 @@ class RWKVWRAPPER(HFLM):
all_outputs = []
if not self.is_hf:
CHUNK_SIZE = 4096
context = context.squeeze()
prefill_ids, next_token = context[:-1], context[-1]
state = None
for i in range(0, len(prefill_ids), CHUNK_SIZE):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment