Commit 3c805576 authored by lintangsutawika

reformat for pre-commit

parent 6b375468
@@ -87,7 +87,7 @@ Our intended output is for the model to predict a single whitespace, and then the answer to the question.
doc_to_target: "{{answer}}"
gold_alias: "{{answer}}"
```
where `doc_to_target` is *the string that will be appended to inputs for each few-shot example*, and `gold_alias` is *what is passed to our metric function as the reference or gold answer to score against*. For example, for GSM8k word problems, `doc_to_target` should be the reference reasoning chain given in the dataset, culminating in the answer, while `gold_alias` should be **only the numeric answer** given at the end of that reasoning chain, against which the evaluated model's answer will be compared.
**Important**: We always add one whitespace between the input and output, such that the full input-output string is `doc_to_text(doc) + " " + doc_to_target(doc)`. `doc_to_text` should therefore not end with whitespace, and `doc_to_target` should not begin with it.
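To make the joining rule concrete, here is a minimal sketch of how a few-shot context could be assembled under it; `build_fewshot_context` and its arguments are hypothetical names for illustration, not the harness's actual API:

```python
def build_fewshot_context(fewshot_docs, eval_doc, doc_to_text, doc_to_target):
    # Each few-shot example is rendered as input + single whitespace + target.
    shots = [doc_to_text(d) + " " + doc_to_target(d) for d in fewshot_docs]
    # The document being scored contributes only its input text; the model
    # must generate the continuation (" " + answer) itself. Examples are
    # separated here by blank lines; the real delimiter may differ.
    return "\n\n".join(shots + [doc_to_text(eval_doc)])
```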
@@ -23,7 +23,7 @@ class HFLM(LM):
        pretrained="gpt2",
        revision="main",
        low_cpu_mem_usage=None,
-        dtype: Optional[Union[str, torch.dtype]]="auto",
+        dtype: Optional[Union[str, torch.dtype]] = "auto",
        subfolder=None,
        tokenizer=None,
        batch_size=1,
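For context, the `dtype` argument being reformatted here accepts either a string or a `torch.dtype`. A hedged usage sketch (the import path is an assumption, not taken from this diff):

```python
import torch

from lm_eval.models.huggingface import HFLM  # assumed import path

# A string dtype is resolved via utils.get_dtype (see the hunk below);
# a torch.dtype passes through unchanged, and the default "auto" defers
# to Hugging Face's torch_dtype="auto" behavior.
lm_bf16 = HFLM(pretrained="gpt2", dtype="bfloat16")
lm_fp16 = HFLM(pretrained="gpt2", dtype=torch.float16)
```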
@@ -59,8 +59,8 @@ class HFLM(LM):
        revision = revision + ("/" + subfolder if subfolder is not None else "")
        self.model = transformers.AutoModelForCausalLM.from_pretrained(
            pretrained,
            revision=revision,
            low_cpu_mem_usage=low_cpu_mem_usage,
            torch_dtype=utils.get_dtype(dtype),
        ).to(self.device)
@@ -421,13 +421,11 @@ def clear_torch_cache():
    torch.cuda.empty_cache()


-def get_dtype(
-    dtype: Union[str, torch.dtype]
-) -> torch.dtype:
+def get_dtype(dtype: Union[str, torch.dtype]) -> torch.dtype:
    """Converts `dtype` from `str` to torch.dtype when possible. Does not use an instantiated HF AutoConfig"""
    if isinstance(dtype, str) and dtype != "auto":
        # Convert `str` args to torch dtype: `float16` -> `torch.float16`
        _torch_dtype = getattr(torch, dtype)
    else:
        _torch_dtype = dtype
-    return _torch_dtype
\ No newline at end of file
+    return _torch_dtype
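As a quick sanity check of the reformatted helper's behavior (assuming it is importable as `lm_eval.utils`, consistent with the `utils.get_dtype(dtype)` call in the earlier hunk):

```python
import torch

from lm_eval import utils  # assumed module path

assert utils.get_dtype("float16") is torch.float16        # str resolved via getattr(torch, ...)
assert utils.get_dtype(torch.bfloat16) is torch.bfloat16  # torch.dtype passes through unchanged
assert utils.get_dtype("auto") == "auto"                  # "auto" is returned as-is
```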