Commit 3285f030 authored by Baber

refactor filter hf to use new output classes

parent 451e73f1
@@ -44,7 +44,7 @@ class FilterEnsemble:
     def apply(self, instances: List[Instance]) -> None:
         resps, docs = zip(*((inst.resps, inst.doc) for inst in instances))
-        resps, docs = list(resps), list(docs)
+        resps, docs = list([r.text] for y in resps for r in y), list(docs)
         for f in self.filters:
             # apply filters in sequence
......
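The hunk above adapts FilterEnsemble.apply to the new response objects: each generation now arrives as an output object rather than a raw string, so the comprehension pulls out .text and wraps every output in a single-element list before the filters run. A minimal sketch of what that comprehension produces, using a hypothetical stand-in for the new output class (only the .text attribute matters here):

    from dataclasses import dataclass

    @dataclass
    class FakeOutput:
        # hypothetical stand-in for the new GenerateOutput class
        text: str

    # one instance with two sampled responses
    resps = [[FakeOutput("answer A"), FakeOutput("answer B")]]

    # same comprehension as in apply(): unwrap each output object into a
    # single-element list of its text, which is what the filters consume
    flattened = list([r.text] for y in resps for r in y)
    assert flattened == [["answer A"], ["answer B"]]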
@@ -12,6 +12,13 @@ class GenerateInput:
     gen_kwargs: dict
     multimodal_arg: Optional[dict] = None

+    def __iter__(self):
+        return (
+            iter((self.prompt, self.gen_kwargs))
+            if not self.multimodal_arg
+            else iter((self.prompt, self.gen_kwargs, self.multimodal_arg))
+        )

 @dataclass
 class GenerateOutput:
......
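The added __iter__ is what keeps request objects compatible with code that used to tuple-unpack or index them positionally (req[0], x[1] further down). A minimal, runnable sketch assuming prompt is the only other field; the real dataclass may carry additional attributes:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class GenerateInput:
        prompt: str
        gen_kwargs: dict
        multimodal_arg: Optional[dict] = None

        def __iter__(self):
            # yields (prompt, gen_kwargs), plus multimodal_arg when it is set
            return (
                iter((self.prompt, self.gen_kwargs))
                if not self.multimodal_arg
                else iter((self.prompt, self.gen_kwargs, self.multimodal_arg))
            )

    # tuple-style unpacking still works for text-only requests ...
    prompt, gen_kwargs = GenerateInput("Hello", {"max_new_tokens": 16})
    # ... and the multimodal argument is only yielded when present
    prompt, gen_kwargs, mm = GenerateInput("Hi", {}, {"image": "img.png"})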
@@ -1321,8 +1321,8 @@ class HFLM(TemplateLM):
            # padded context length. this is useful to simplify the batching logic and more importantly to make
            # automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end
-           toks = self.tok_encode(req[0])
-           return -len(toks), req[0]
+           toks = self.tok_encode(req.prompt)
+           return -len(toks), req.prompt

        pbar = tqdm(
            total=len(requests),
@@ -1358,7 +1358,7 @@ class HFLM(TemplateLM):
            [reg.args for reg in requests],
            sort_fn=_collate,
            group_by="gen_kwargs",
-           group_fn=lambda x: x[1],
+           group_fn=lambda x: x.gen_kwargs,
        )
        chunks = re_ords.get_batched(n=batch_size, batch_fn=batch_fn)
        eos = self.tok_decode(self.eot_token_id, skip_special_tokens=False)
......
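With requests now carrying attributes instead of being plain tuples, the sort key reads req.prompt and the Collator's group_fn reads x.gen_kwargs rather than indexing positions 0 and 1. A standalone sketch of the sorting and grouping behaviour, with the tokenizer replaced by a whitespace split purely for illustration (not the library's Collator API):

    from types import SimpleNamespace

    def _collate(req):
        # longest prompts first: any OOM shows up on the very first batch
        # instead of near the end of the run
        toks = req.prompt.split()  # stand-in for self.tok_encode(req.prompt)
        return -len(toks), req.prompt

    reqs = [
        SimpleNamespace(prompt="hi", gen_kwargs={"max_new_tokens": 64}),
        SimpleNamespace(prompt="tell me a long story about llamas",
                        gen_kwargs={"max_new_tokens": 64}),
    ]
    reqs.sort(key=_collate)
    assert reqs[0].prompt.startswith("tell")

    # grouping key passed to the Collator: requests that share identical
    # gen_kwargs can be batched and generated with the same settings
    group_fn = lambda x: x.gen_kwargs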