"models/vision/ddim/example.py" did not exist on "485797b846c16f917b46d59527c1c167116491a1"
Commit 13317f8c authored by Jon Tow's avatar Jon Tow
Browse files

Fix training docs caching

parent 7c9da714
...@@ -42,8 +42,8 @@ class ARCEasy(MultipleChoiceTask): ...@@ -42,8 +42,8 @@ class ARCEasy(MultipleChoiceTask):
def training_docs(self): def training_docs(self):
if self._training_docs is None: if self._training_docs is None:
self._training_docs = list(self.dataset["train"]) self._training_docs = list(map(self._convert_standard, self.dataset["train"]))
return map(self._convert_standard, self._training_docs) return self._training_docs
def validation_docs(self): def validation_docs(self):
return map(self._convert_standard, self.dataset["validation"]) return map(self._convert_standard, self.dataset["validation"])
......
...@@ -52,7 +52,9 @@ class DROP(Task): ...@@ -52,7 +52,9 @@ class DROP(Task):
return False return False
def training_docs(self): def training_docs(self):
return map(self._convert_standard, self.dataset["train"]) if self._training_docs is None:
self._training_docs = list(map(self._convert_standard, self.dataset["train"]))
return self._training_docs
def validation_docs(self): def validation_docs(self):
return map(self._convert_standard, self.dataset["validation"]) return map(self._convert_standard, self.dataset["validation"])
......
...@@ -38,8 +38,8 @@ class HeadQABase(MultipleChoiceTask): ...@@ -38,8 +38,8 @@ class HeadQABase(MultipleChoiceTask):
def training_docs(self): def training_docs(self):
if self._training_docs is None: if self._training_docs is None:
self._training_docs = list(self.dataset["train"]) self._training_docs = list(map(self._convert_standard, self.dataset["train"]))
return map(self._convert_standard, self._training_docs) return self._training_docs
def validation_docs(self): def validation_docs(self):
return map(self._convert_standard, self.dataset["validation"]) return map(self._convert_standard, self.dataset["validation"])
......
...@@ -43,8 +43,8 @@ class HellaSwag(MultipleChoiceTask): ...@@ -43,8 +43,8 @@ class HellaSwag(MultipleChoiceTask):
def training_docs(self): def training_docs(self):
if self._training_docs is None: if self._training_docs is None:
self._training_docs = list(self.dataset["train"]) self._training_docs = list(map(self._convert_standard, self.dataset["train"]))
return map(self._convert_standard, self._training_docs) return self._training_docs
def validation_docs(self): def validation_docs(self):
return map(self._convert_standard, self.dataset["validation"]) return map(self._convert_standard, self.dataset["validation"])
......
...@@ -40,8 +40,8 @@ class MathQA(MultipleChoiceTask): ...@@ -40,8 +40,8 @@ class MathQA(MultipleChoiceTask):
def training_docs(self): def training_docs(self):
if self._training_docs is None: if self._training_docs is None:
self._training_docs = list(self.dataset["train"]) self._training_docs = list(map(self._convert_standard, self.dataset["train"]))
return map(self._convert_standard, self._training_docs) return self._training_docs
def validation_docs(self): def validation_docs(self):
return map(self._convert_standard, self.dataset["validation"]) return map(self._convert_standard, self.dataset["validation"])
......
...@@ -43,8 +43,8 @@ class OpenBookQA(MultipleChoiceTask): ...@@ -43,8 +43,8 @@ class OpenBookQA(MultipleChoiceTask):
def training_docs(self): def training_docs(self):
if self._training_docs is None: if self._training_docs is None:
self._training_docs = list(self.dataset["train"]) self._training_docs = list(map(self._convert_standard, self.dataset["train"]))
return map(self._convert_standard, self._training_docs) return self._training_docs
def validation_docs(self): def validation_docs(self):
return map(self._convert_standard, self.dataset["validation"]) return map(self._convert_standard, self.dataset["validation"])
......
...@@ -42,8 +42,8 @@ class PiQA(MultipleChoiceTask): ...@@ -42,8 +42,8 @@ class PiQA(MultipleChoiceTask):
def training_docs(self): def training_docs(self):
if self._training_docs is None: if self._training_docs is None:
self._training_docs = list(self.dataset["train"]) self._training_docs = list(map(self._convert_standard, self.dataset["train"]))
return map(self._convert_standard, self._training_docs) return self._training_docs
def validation_docs(self): def validation_docs(self):
return map(self._convert_standard, self.dataset["validation"]) return map(self._convert_standard, self.dataset["validation"])
......
...@@ -40,7 +40,9 @@ class QuAC(Task): ...@@ -40,7 +40,9 @@ class QuAC(Task):
return False return False
def training_docs(self): def training_docs(self):
return map(self._convert_standard, self.dataset["train"]) if self._training_docs is None:
self._training_docs = list(map(self._convert_standard, self.dataset["train"]))
return self._training_docs
def validation_docs(self): def validation_docs(self):
return map(self._convert_standard, self.dataset["validation"]) return map(self._convert_standard, self.dataset["validation"])
......
...@@ -38,8 +38,8 @@ class SciQ(MultipleChoiceTask): ...@@ -38,8 +38,8 @@ class SciQ(MultipleChoiceTask):
def training_docs(self): def training_docs(self):
if self._training_docs is None: if self._training_docs is None:
self._training_docs = list(self.dataset["train"]) self._training_docs = list(map(self._convert_standard, self.dataset["train"]))
return map(self._convert_standard, self._training_docs) return self._training_docs
def validation_docs(self): def validation_docs(self):
return map(self._convert_standard, self.dataset["validation"]) return map(self._convert_standard, self.dataset["validation"])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment