Unverified Commit bd7d265a authored by Hailey Schoelkopf's avatar Hailey Schoelkopf Committed by GitHub
Browse files

Fix unintuitive `--gen_kwargs` behavior (#1329)

* don't override do_sample if no value for it is passed

* Update gen_kwargs override condition

* Update huggingface.py

* Update huggingface.py

* run linters

* silence an erroneous warning
parent 1554066c
......@@ -140,9 +140,12 @@ def simple_evaluate(
)
else:
default_num_fewshot = config["num_fewshot"]
eval_logger.warning(
f"Overwriting default num_fewshot of {task_name} from {default_num_fewshot} to {num_fewshot}"
)
if default_num_fewshot:
# warn a user, if a specific num_fewshot > 0 was specified.
# if unspecified in config, no warning message
eval_logger.warning(
f"Overwriting default num_fewshot of {task_name} from {default_num_fewshot} to {num_fewshot}"
)
task_obj._config["num_fewshot"] = num_fewshot
......
......@@ -705,10 +705,12 @@ class HFLM(LM):
return self.model(inps).logits
def _model_generate(self, context, max_length, stop, **generation_kwargs):
# we require users to pass do_sample=True explicitly
# for non-greedy gen. This should be reevaluated when considering beam search.
if "do_sample" not in generation_kwargs:
generation_kwargs["do_sample"] = False
# if do_sample is false and temp==0.0:
# remove temperature, as do_sample=False takes care of this
# and we don't want a warning from HF
do_sample = generation_kwargs.get("do_sample", None)
if do_sample is False and "temperature" == 0.0:
generation_kwargs.pop("temperature", 0.0)
# build stopping criteria
stopping_criteria = stop_sequences_criteria(
self.tokenizer, stop, context.shape[1], context.shape[0]
......
......@@ -24,6 +24,7 @@ generation_kwargs:
- "\n\n"
- "Question:"
do_sample: false
temperature: 0.0
repeats: 1
num_fewshot: 5
filter_list:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment