Unverified Commit 35a65ba0 authored by Baber Abbasi, committed by GitHub

add utils.clear_torch_cache() (#1142)

parent b0d155d3
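
The commit swaps a bare torch.cuda.empty_cache() call for the shared lm_eval.utils.clear_torch_cache() helper and logs GPU memory before and after the cleanup. The helper's body is not shown in this diff; below is a minimal sketch of what such a helper typically does, assuming it combines Python garbage collection with a CUDA allocator flush:

import gc
import torch

def clear_torch_cache():
    # Drop unreachable Python objects first so their CUDA tensors are
    # actually freed and the allocator can release the underlying blocks.
    gc.collect()
    # Hand cached-but-unused blocks back to the CUDA driver so the next
    # model (or another process) can use the memory.
    torch.cuda.empty_cache()

Running gc.collect() before empty_cache() matters: blocks still referenced by live tensors stay allocated no matter how often the cache is emptied.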
@@ -2,6 +2,7 @@ import argparse
 import numpy as np
 import lm_eval.evaluator
 from lm_eval import tasks
+from lm_eval import utils
 import scipy.stats
 from typing import Tuple, Dict, List
 import pandas as pd
@@ -9,7 +10,13 @@ import torch
 import os

 os.environ["TOKENIZERS_PARALLELISM"] = "false"
-eval_logger = lm_eval.utils.eval_logger
+eval_logger = utils.eval_logger
+
+
+def memory_stats():
+    eval_logger.info(
+        f"Memory allocated: {torch.cuda.memory_allocated() / 1024 ** 2}, reserved: {torch.cuda.memory_reserved() // 1024 ** 2}"
+    )


 def calculate_z_value(res1: Dict, res2: Dict) -> Tuple[float, float]:
@@ -103,7 +110,10 @@ if __name__ == "__main__":
         device=args.device,
         batch_size=args.batch,
     )
-    torch.cuda.empty_cache()
+    memory_stats()
+    utils.clear_torch_cache()
+    eval_logger.info("Memory stats cleared")
+    memory_stats()
     results_hf = lm_eval.evaluator.simple_evaluate(
         model="hf",
         model_args=f"pretrained={args.pretrained}" + hf_args,
...
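
For context, the report-clear-report pattern the last hunk introduces can be exercised on its own. A minimal standalone sketch, assuming a CUDA device is available; the tensor x and its size are purely illustrative:

import gc
import torch

def memory_stats():
    # Same numbers the patched script logs: bytes currently allocated to
    # live tensors vs. bytes reserved by the caching allocator, in MiB.
    print(
        f"allocated: {torch.cuda.memory_allocated() / 1024 ** 2:.1f} MiB, "
        f"reserved: {torch.cuda.memory_reserved() / 1024 ** 2:.1f} MiB"
    )

if torch.cuda.is_available():
    x = torch.randn(4096, 4096, device="cuda")  # ~64 MiB illustrative allocation
    memory_stats()            # both counters include the tensor
    del x                     # drop the only reference
    gc.collect()              # make sure the tensor is really gone
    torch.cuda.empty_cache()  # what clear_torch_cache() boils down to
    memory_stats()            # reserved should fall back toward zero

The two numbers diverge by design: PyTorch keeps freed blocks reserved for reuse, so only empty_cache() (and hence clear_torch_cache()) pushes the reserved figure back down.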
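Zooming out, the script this commit touches runs two evaluations back to back on the same device, so the cleanup sits between them to make room for the second model. A sketch of that flow with lm_eval's simple_evaluate; the checkpoint and task names here are placeholders, not the script's actual arguments:

import lm_eval.evaluator
from lm_eval import utils

# First evaluation run (illustrative model and task).
results_a = lm_eval.evaluator.simple_evaluate(
    model="hf",
    model_args="pretrained=EleutherAI/pythia-160m",
    tasks=["lambada_openai"],
    batch_size=8,
    device="cuda:0",
)

# Free the first model's GPU memory before loading the second,
# exactly where the commit inserts utils.clear_torch_cache().
utils.clear_torch_cache()

# Second evaluation run on the now-freed device.
results_b = lm_eval.evaluator.simple_evaluate(
    model="hf",
    model_args="pretrained=EleutherAI/pythia-410m",
    tasks=["lambada_openai"],
    batch_size=8,
    device="cuda:0",
)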