Commit 3285f030 authored by Baber's avatar Baber
Browse files

refactor filter hf to use new output classes

parent 451e73f1
...@@ -44,7 +44,7 @@ class FilterEnsemble: ...@@ -44,7 +44,7 @@ class FilterEnsemble:
def apply(self, instances: List[Instance]) -> None: def apply(self, instances: List[Instance]) -> None:
resps, docs = zip(*((inst.resps, inst.doc) for inst in instances)) resps, docs = zip(*((inst.resps, inst.doc) for inst in instances))
resps, docs = list(resps), list(docs) resps, docs = list([r.text] for y in resps for r in y), list(docs)
for f in self.filters: for f in self.filters:
# apply filters in sequence # apply filters in sequence
......
...@@ -12,6 +12,13 @@ class GenerateInput: ...@@ -12,6 +12,13 @@ class GenerateInput:
gen_kwargs: dict gen_kwargs: dict
multimodal_arg: Optional[dict] = None multimodal_arg: Optional[dict] = None
def __iter__(self):
return (
iter((self.prompt, self.gen_kwargs))
if not self.multimodal_arg
else iter((self.prompt, self.gen_kwargs, self.multimodal_arg))
)
@dataclass @dataclass
class GenerateOutput: class GenerateOutput:
......
...@@ -1321,8 +1321,8 @@ class HFLM(TemplateLM): ...@@ -1321,8 +1321,8 @@ class HFLM(TemplateLM):
# padded context length. this is useful to simplify the batching logic and more importantly to make # padded context length. this is useful to simplify the batching logic and more importantly to make
# automatic adaptive batches much much easier to implement # automatic adaptive batches much much easier to implement
# - any OOMs will happen right away rather than near the end # - any OOMs will happen right away rather than near the end
toks = self.tok_encode(req[0]) toks = self.tok_encode(req.prompt)
return -len(toks), req[0] return -len(toks), req.prompt
pbar = tqdm( pbar = tqdm(
total=len(requests), total=len(requests),
...@@ -1358,7 +1358,7 @@ class HFLM(TemplateLM): ...@@ -1358,7 +1358,7 @@ class HFLM(TemplateLM):
[reg.args for reg in requests], [reg.args for reg in requests],
sort_fn=_collate, sort_fn=_collate,
group_by="gen_kwargs", group_by="gen_kwargs",
group_fn=lambda x: x[1], group_fn=lambda x: x.gen_kwargs,
) )
chunks = re_ords.get_batched(n=batch_size, batch_fn=batch_fn) chunks = re_ords.get_batched(n=batch_size, batch_fn=batch_fn)
eos = self.tok_decode(self.eot_token_id, skip_special_tokens=False) eos = self.tok_decode(self.eot_token_id, skip_special_tokens=False)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment