Commit d03c9fde authored by haileyschoelkopf

add the hack (works for Mistral/Llama, destroys performance for GPT2)

parent 68c30aa7
@@ -792,6 +792,26 @@ class HFLM(LM):
# context_enc = self.tok_encode(context, add_special_tokens=False)
context_enc_len = len(context_enc)
continuation_enc = whole_enc[context_enc_len:]
# quite the hack, but what this does:
# it circumvents the addition of an extraneous SentencePiece underline token
# that is produced when passing " <word>" into the Llama / Mistral tokenizer.
# if we instead pass "<word>" in, we don't get this extra token (29871 for Llama),
# which would hurt performance if it were included.
if (
len(continuation.lstrip()) + 1 == len(continuation)
and continuation.startswith(" ")
) or (len(continuation_enc) == 0):
    context_enc_2 = context_enc
    continuation_enc_2 = self.tok_encode(
        continuation[1:], add_special_tokens=False
    )
    # assert context_enc == context_enc_2
    # assert continuation_enc == continuation_enc_2, f"{continuation_enc},{continuation_enc_2}"
    return context_enc_2, continuation_enc_2
return context_enc, continuation_enc
def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
...
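For context, here is a minimal sketch (not part of this commit) of the tokenizer behaviour the hack works around. It assumes the Hugging Face transformers AutoTokenizer; the checkpoint name "meta-llama/Llama-2-7b-hf" and the continuation string " Paris" are illustrative, and the exact token ids depend on the tokenizer version. Only token id 29871 comes from the diff's comments.

# Illustrative only: demonstrates the extraneous SentencePiece underline token
# described in the comments above. The checkpoint is an assumed example.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

continuation = " Paris"

# Encoding the continuation with its leading space prepends the underline
# token (29871 for Llama) to the sequence ...
print(tok.encode(continuation, add_special_tokens=False))

# ... while encoding it without the leading space does not, which is what the
# hack above relies on when it re-encodes continuation[1:].
print(tok.encode(continuation[1:], add_special_tokens=False))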