Unverified Commit a963ab81 authored by Charles Lovering's avatar Charles Lovering Committed by GitHub
Browse files

Merge pull request #21 from tomlimi/tomlimi/fix_fs_write_out

Fixed issue with write_out for datasets without a training split
parents 256de63b 49f0699c
...@@ -915,12 +915,12 @@ class PromptSourceTask(Task): ...@@ -915,12 +915,12 @@ class PromptSourceTask(Task):
if num_fewshot == 0: if num_fewshot == 0:
labeled_examples = "" labeled_examples = ""
fewshotex, fewshotidx, fewshotsource = [], [], None fewshotex, fewshotidx, self.fewshotsource = [], [], None
else: else:
# for sets with no training docs, draw from other set *but ensure no overlap with current doc* # for sets with no training docs, draw from other set *but ensure no overlap with current doc*
if self.has_training_docs(): if self.has_training_docs():
fewshotex, fewshotidx = self.fewshot_examples(k=num_fewshot, rnd=rnd) fewshotex, fewshotidx = self.fewshot_examples(k=num_fewshot, rnd=rnd)
fewshotsource = "train" self.fewshotsource = "train"
else: else:
if self._fewshot_docs is None: if self._fewshot_docs is None:
self._fewshot_docs = list( self._fewshot_docs = list(
...@@ -929,18 +929,18 @@ class PromptSourceTask(Task): ...@@ -929,18 +929,18 @@ class PromptSourceTask(Task):
else self.test_docs() else self.test_docs()
) )
if self.has_validation_docs(): if self.has_validation_docs():
fewshotsource = "val" self.fewshotsource = "val"
elif self.test_docs(): elif self.test_docs():
fewshotsource = "test" self.fewshotsource = "test"
fewshotex, fewshotidx = self._get_fewshot_examples( fewshotex, fewshotidx = self._get_fewshot_examples(
self._fewshot_docs, k=num_fewshot + 1, rnd=rnd self._fewshot_docs, k=num_fewshot + 1, rnd=rnd
) )
fewshotex, fewshotidx = [ fewshotex, fewshotidx = zip(*[
(shot, idx) (shot, idx)
for shot, idx in zip(fewshotex, fewshotidx) for shot, idx in zip(fewshotex, fewshotidx)
if shot != doc if shot != doc
] ])
# get rid of the doc that's the one we're evaluating, if it's in the fewshot # get rid of the doc that's the one we're evaluating, if it's in the fewshot
fewshotex, fewshotidx = ( fewshotex, fewshotidx = (
fewshotex[:num_fewshot], fewshotex[:num_fewshot],
...@@ -966,7 +966,7 @@ class PromptSourceTask(Task): ...@@ -966,7 +966,7 @@ class PromptSourceTask(Task):
ctx, ctx,
{ {
"fewshot_idx": fewshotidx, "fewshot_idx": fewshotidx,
"fewshot_source": fewshotsource, "fewshot_source": self.fewshotsource,
"fewshot_num": num_fewshot, "fewshot_num": num_fewshot,
"ctx": ctx, "ctx": ctx,
}, },
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment