Unverified Commit 722eb395 authored by cdpath's avatar cdpath Committed by GitHub
Browse files

fix potential oom issue (#387)

parent b9b145c3
...@@ -141,6 +141,7 @@ class MDLRetriever(TopkRetriever): ...@@ -141,6 +141,7 @@ class MDLRetriever(TopkRetriever):
"""Retrieve the in-context example index for each test example.""" """Retrieve the in-context example index for each test example."""
return self.topk_search() return self.topk_search()
@torch.no_grad()
def cal_ce(self, input_texts: List[str], mask_length=None): def cal_ce(self, input_texts: List[str], mask_length=None):
if self.metric_model is None: if self.metric_model is None:
logger.info( logger.info(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment