Commit 103a10e3 authored by baberabb

bugfix

parent a2920e3d
@@ -80,6 +80,8 @@ class VLLM(LM):
         stop: Optional[List[str]] = None,
         **kwargs,
     ):
+        if "do_sample" in kwargs.keys():
+            kwargs.pop("do_sample")
         if generate:
             generate_sampling_params = SamplingParams(
                 max_tokens=max_tokens, stop=stop, **kwargs
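
For context, `do_sample` is a HuggingFace `transformers` generation flag that task configs commonly carry, but vLLM's `SamplingParams` does not accept it and rejects unknown keyword arguments, so it has to be stripped before the splat. A minimal sketch of the failure mode (assuming a working `vllm` install; the kwarg values are illustrative):

```python
from vllm import SamplingParams

# Example gen_kwargs as they might arrive from an HF-style task config;
# `do_sample` has no SamplingParams counterpart and would raise a TypeError.
kwargs = {"temperature": 0.0, "do_sample": False}

# One-step equivalent of the guard added in this commit.
kwargs.pop("do_sample", None)

print(SamplingParams(max_tokens=16, stop=["\n"], **kwargs))
```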
@@ -181,7 +183,8 @@ class VLLM(LM):
             fn=None,
         )
         for chunk in chunks:
-            context, context_encoding, all_gen_kwargs = zip(*chunk)
+            context_and_encoding, all_gen_kwargs = zip(*chunk)
+            context, context_encoding = zip(*context_and_encoding)
             # we assume all gen kwargs in the batch are the same
             # this is safe to assume because the `grouper` object ensures it.
             gen_kwargs = all_gen_kwargs[0]
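
The second hunk tracks a change in the shape of each request tuple: every chunk element now pairs `(context, context_encoding)` with its `gen_kwargs`, so `zip(*chunk)` yields two sequences rather than three, and a second unzip is needed to split the context/encoding pairs. A minimal sketch with hypothetical values:

```python
# Each chunk element: ((context, context_encoding), gen_kwargs).
chunk = [
    (("ctx A", [1, 2, 3]), {"temperature": 0.0}),
    (("ctx B", [4, 5]), {"temperature": 0.0}),
]

# Old unpacking raises ValueError: zip(*chunk) produces 2 items, not 3.
# context, context_encoding, all_gen_kwargs = zip(*chunk)

context_and_encoding, all_gen_kwargs = zip(*chunk)
context, context_encoding = zip(*context_and_encoding)

print(context)            # ('ctx A', 'ctx B')
print(context_encoding)   # ([1, 2, 3], [4, 5])
print(all_gen_kwargs[0])  # {'temperature': 0.0}
```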