Commit 41f807a8 authored by Baber's avatar Baber
Browse files

fix: bug in acc_mutual_info slicing; add `target_delimiter` to uncond choices

parent 9d29ef0e
......@@ -1481,7 +1481,10 @@ class ConfigurableTask(Task):
# here mutual info refers to calculating
# log(P(choice|ctx) / P(choice)) = log(P(choice|ctx)) - log(P(choice))
# in other words normalizing by subtracting the unconditional logprob of each choice.
aux_arguments = [("", f"{choice}") for choice in choices]
# TODO: should these be strided? will have to modify the processing in process_results if so
aux_arguments = [
("", f"{target_delimiter}{choice}") for choice in choices
]
arguments.extend(aux_arguments)
......@@ -1580,11 +1583,12 @@ class ConfigurableTask(Task):
):
# then we are doing mutual info.
# this stores the "dryrun" / unconditional answer loglikelihoods
lls_unconditional = lls[1::2]
# as we extend the args list with unconditional ("", continuation) pairs
lls_unconditional = lls[len(choices) :]
if len(lls_unconditional) != len(choices):
raise ValueError
# and this stores our "regular" conditional loglikelihoods
lls = lls[::2]
lls = lls[: len(choices)]
pred = np.argmax(lls)
pred_norm = np.argmax(lls / completion_len)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment