Unverified Commit 9bb0de9c authored by Stella Biderman's avatar Stella Biderman Committed by GitHub
Browse files

Merge pull request #390 from sxjscience/master

Enable "low_cpu_mem_usage" to reduce the memory usage of HF models
parents 3d14707a 859f96fd
import transformers
import torch import torch
import transformers
from lm_eval.base import BaseLM from lm_eval.base import BaseLM
...@@ -9,6 +9,7 @@ class HFLM(BaseLM): ...@@ -9,6 +9,7 @@ class HFLM(BaseLM):
device="cuda", device="cuda",
pretrained="gpt2", pretrained="gpt2",
revision="main", revision="main",
low_cpu_mem_usage=None,
subfolder=None, subfolder=None,
tokenizer=None, tokenizer=None,
batch_size=1, batch_size=1,
...@@ -37,8 +38,7 @@ class HFLM(BaseLM): ...@@ -37,8 +38,7 @@ class HFLM(BaseLM):
revision = revision + ("/" + subfolder if subfolder is not None else "") revision = revision + ("/" + subfolder if subfolder is not None else "")
self.gpt2 = transformers.AutoModelForCausalLM.from_pretrained( self.gpt2 = transformers.AutoModelForCausalLM.from_pretrained(
pretrained, pretrained, revision=revision, low_cpu_mem_usage=low_cpu_mem_usage
revision=revision,
).to(self.device) ).to(self.device)
self.gpt2.eval() self.gpt2.eval()
......
...@@ -7,6 +7,8 @@ import inspect ...@@ -7,6 +7,8 @@ import inspect
import sys import sys
from typing import List from typing import List
from omegaconf import OmegaConf
class ExitCodeError(Exception): class ExitCodeError(Exception):
pass pass
...@@ -27,10 +29,7 @@ def simple_parse_args_string(args_string): ...@@ -27,10 +29,7 @@ def simple_parse_args_string(args_string):
if not args_string: if not args_string:
return {} return {}
arg_list = args_string.split(",") arg_list = args_string.split(",")
args_dict = {} args_dict = OmegaConf.to_object(OmegaConf.from_dotlist(arg_list))
for arg in arg_list:
k, v = arg.split("=")
args_dict[k] = v
return args_dict return args_dict
......
...@@ -25,6 +25,7 @@ setuptools.setup( ...@@ -25,6 +25,7 @@ setuptools.setup(
"jsonlines", "jsonlines",
"numexpr", "numexpr",
"openai>=0.6.4", "openai>=0.6.4",
"omegaconf>=2.2",
"pybind11>=2.6.2", "pybind11>=2.6.2",
"pycountry", "pycountry",
"pytablewriter", "pytablewriter",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment