Commit df94cfdd authored by Baber
Browse files

nit

parent 4ee812a0
...@@ -90,9 +90,8 @@ class RWKVWRAPPER(HFLM): ...@@ -90,9 +90,8 @@ class RWKVWRAPPER(HFLM):
local_dir="rwkv_model", local_dir="rwkv_model",
) )
self._model = RWKV( self._model = RWKV(model=f"rwkv_model/{pretrained}", strategy="cuda fp16")
model=f"rwkv_model/{pretrained}", strategy=f"cuda {dtype}" self._model.tie_weights = lambda: None
)
def _model_generate(self, context, max_length, stop, **generation_kwargs): def _model_generate(self, context, max_length, stop, **generation_kwargs):
remove_arg = ( remove_arg = (
...@@ -105,6 +104,7 @@ class RWKVWRAPPER(HFLM): ...@@ -105,6 +104,7 @@ class RWKVWRAPPER(HFLM):
all_outputs = [] all_outputs = []
if not self.is_hf: if not self.is_hf:
CHUNK_SIZE = 4096 CHUNK_SIZE = 4096
context = context.squeeze()
prefill_ids, next_token = context[:-1], context[-1] prefill_ids, next_token = context[:-1], context[-1]
state = None state = None
for i in range(0, len(prefill_ids), CHUNK_SIZE): for i in range(0, len(prefill_ids), CHUNK_SIZE):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment