Commit 452d49e0 authored by baberabb's avatar baberabb
Browse files

fix `add_special_tokens=False` in generate_until + decoding fixes

parent 1c62da1d
...@@ -139,6 +139,11 @@ please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`" ...@@ -139,6 +139,11 @@ please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`"
if "do_sample" in kwargs.keys(): if "do_sample" in kwargs.keys():
kwargs.pop("do_sample") kwargs.pop("do_sample")
if generate: if generate:
# hf defaults
kwargs["skip_special_tokens"] = kwargs.get("skip_special_tokens", False)
kwargs["spaces_between_special_tokens"] = kwargs.get(
"spaces_between_special_tokens", False
)
sampling_params = SamplingParams(max_tokens=max_tokens, stop=stop, **kwargs) sampling_params = SamplingParams(max_tokens=max_tokens, stop=stop, **kwargs)
else: else:
sampling_params = SamplingParams( sampling_params = SamplingParams(
...@@ -226,7 +231,7 @@ please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`" ...@@ -226,7 +231,7 @@ please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`"
# batch tokenize contexts # batch tokenize contexts
context, all_gen_kwargs = zip(*(req.args for req in requests)) context, all_gen_kwargs = zip(*(req.args for req in requests))
context_encoding = self.tokenizer(context).input_ids context_encoding = self.tokenizer(context, add_special_tokens=False).input_ids
requests = [ requests = [
((a, b), c) for a, b, c in zip(context, context_encoding, all_gen_kwargs) ((a, b), c) for a, b, c in zip(context, context_encoding, all_gen_kwargs)
] ]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment