Unverified commit 5da33f87, authored by Stas Bekman, committed by GitHub

[modeling utils] revamp `from_pretrained(..., low_cpu_mem_usage=True)` + tests (#16657)

* add low_cpu_mem_usage tests

* wip: revamping

* wip

* install /usr/bin/time

* wip

* cleanup

* cleanup

* cleanup

* cleanup

* cleanup

* fix assert

* put the wrapper back

* cleanup; switch to bert-base-cased

* Trigger CI

* Trigger CI
parent ce2fef2a
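
As a quick orientation before the diff: the code path this commit reworks is the one triggered by `low_cpu_mem_usage=True`. A minimal usage sketch (the checkpoint name is simply the one the new tests below use; any PyTorch checkpoint on the Hub works the same way):

```python
# With low_cpu_mem_usage=True, from_pretrained keeps peak CPU RAM near ~1x the
# model size: the checkpoint's state_dict keys are recorded, the matching params
# of the freshly instantiated model are parked on the meta device, and the
# checkpoint tensors are then materialized into the model one by one.
from transformers import AutoModel

model = AutoModel.from_pretrained("bert-base-cased", low_cpu_mem_usage=True)
```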
@@ -217,7 +217,7 @@ jobs:
           keys:
             - v0.4-torch-{{ checksum "setup.py" }}
             - v0.4-{{ checksum "setup.py" }}
-      - run: sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev espeak-ng
+      - run: sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev espeak-ng time
       - run: pip install --upgrade pip
       - run: pip install .[sklearn,torch,testing,sentencepiece,torch-speech,vision,timm]
       - run: pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.11.0+cpu.html
......
@@ -400,6 +400,95 @@ def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
     return error_msgs
+
+
+def find_submodule_and_param_name(model, long_key, start_prefix):
+    """
+    A helper util to find the last sub-module and the param/buffer name. If `start_prefix` is supplied, it'll be
+    removed from the start of the key.
+    """
+
+    if len(start_prefix) > 0 and long_key.startswith(start_prefix):
+        long_key = ".".join(long_key.split(".")[1:])
+
+    split_key = long_key.split(".")
+    submodule = model
+    while len(split_key) > 1:
+        if hasattr(submodule, split_key[0]):
+            submodule = getattr(submodule, split_key[0])
+            del split_key[0]
+        else:
+            submodule = None
+            break
+    if submodule == model:
+        submodule = None
+
+    return submodule, split_key[0]
+
+
+def _move_model_to_meta(model, loaded_state_dict_keys, start_prefix):
+    """
+    Moves `loaded_state_dict_keys` in model to the meta device, which frees up the memory taken by those params.
+
+    `start_prefix` is used for models which insert their name into model keys, e.g. `bert` in
+    `bert.pooler.dense.weight`
+    """
+
+    # meta device was added in pt=1.9
+    require_version_core("torch>=1.9")
+
+    # dematerialize param storage for keys that are going to be replaced by state_dict, by
+    # putting those on the meta device
+    for k in loaded_state_dict_keys:
+        submodule, param_name = find_submodule_and_param_name(model, k, start_prefix)
+        if submodule is not None:
+            # selectively switch to the meta device only those params/buffers that will
+            # be next replaced from state_dict. This is a complex way to do p.to_("meta")
+            # since we have no in-place to_ for tensors.
+            new_val = getattr(submodule, param_name)
+            if isinstance(new_val, torch.nn.Parameter):
+                # isinstance returns False for Params on meta device, so switch after the check
+                new_val = torch.nn.Parameter(new_val.to("meta"))
+            else:
+                new_val = new_val.to("meta")
+            setattr(submodule, param_name, new_val)
+
+
+def _load_state_dict_into_meta_model(model, state_dict, loaded_state_dict_keys, start_prefix):
+    """
+    This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
+    params on a `meta` device. It replaces the model params with the data from the `state_dict`, while moving the
+    params back to the normal device, but only for `loaded_state_dict_keys`.
+
+    `start_prefix` is used for models which insert their name into model keys, e.g. `bert` in
+    `bert.pooler.dense.weight`
+    """
+
+    # XXX: remaining features to implement to be fully compatible with _load_state_dict_into_model
+    # - deepspeed zero 3 support
+    # - need to copy metadata if any - see _load_state_dict_into_model
+    # - handling error_msgs - mimicking the error handling in module._load_from_state_dict()
+    # - is there a situation where some keys aren't in `loaded_state_dict_keys`, in which case they won't get loaded?
+
+    if is_deepspeed_zero3_enabled():
+        raise ValueError("low_cpu_mem_usage arg cannot currently be used with DeepSpeed ZeRO-3")
+
+    error_msgs = []
+
+    # materialize state_dict entries one by one on CPU
+    for k in loaded_state_dict_keys:
+        if k in state_dict:
+            submodule, param_name = find_submodule_and_param_name(model, k, start_prefix)
+            if submodule is not None:
+                param_dtype = getattr(submodule, param_name).dtype
+                new_val = state_dict[k].to(param_dtype)
+                if isinstance(getattr(submodule, param_name), torch.nn.Parameter):
+                    new_val = torch.nn.Parameter(new_val)
+                setattr(submodule, param_name, new_val)
+
+    return error_msgs
 
 
 class ModuleUtilsMixin:
     """
     A few utilities for `torch.nn.Modules`, to be used as a mixin.
@@ -1529,7 +1618,24 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         >>> model = BertModel.from_pretrained("./tf_model/my_tf_checkpoint.ckpt.index", from_tf=True, config=config)
         >>> # Loading from a Flax checkpoint file instead of a PyTorch model (slower)
         >>> model = BertModel.from_pretrained("bert-base-uncased", from_flax=True)
-        ```"""
+        ```
+
+        * `low_cpu_mem_usage` algorithm:
+
+        This is an experimental function that loads the model using ~1x model size CPU memory
+
+        Here is how it works:
+
+        1. save which state_dict keys we have
+        2. drop state_dict before the model is created, since the latter takes 1x model size CPU memory
+        3. after the model has been instantiated, switch to the meta device all params/buffers that
+           are going to be replaced from the loaded state_dict
+        4. load state_dict 2nd time
+        5. replace the params/buffers from the state_dict
+
+        Currently, it can't handle deepspeed ZeRO stage 3 and ignores loading errors
+
+        """
         config = kwargs.pop("config", None)
         state_dict = kwargs.pop("state_dict", None)
         cache_dir = kwargs.pop("cache_dir", None)
@@ -1778,6 +1884,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         if not is_sharded and state_dict is None:
             # Time to load the checkpoint
             state_dict = load_state_dict(resolved_archive_file)
+
         # set dtype to instantiate the model under:
         # 1. If torch_dtype is not None, we use that dtype
         # 2. If torch_dtype is "auto", we auto-detect dtype from the loaded state_dict, by checking its first
@@ -1801,13 +1908,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                         )
                 dtype_orig = cls._set_default_torch_dtype(torch_dtype)
 
-            if low_cpu_mem_usage:
-                # save the keys
-                if is_sharded:
-                    loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
-                else:
-                    loaded_state_dict_keys = [k for k in state_dict.keys()]
-                del state_dict  # free CPU memory - will reload again later
+            if is_sharded:
+                loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
+            else:
+                loaded_state_dict_keys = [k for k in state_dict.keys()]
+            if low_cpu_mem_usage:
+                state_dict = None
 
         config.name_or_path = pretrained_model_name_or_path
 
@@ -1825,11 +1931,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         with no_init_weights(_enable=_fast_init):
             model = cls(config, *model_args, **model_kwargs)
 
-        if from_pt:
-            # restore default dtype
-            if dtype_orig is not None:
-                torch.set_default_dtype(dtype_orig)
-
         if from_tf:
             if resolved_archive_file.endswith(".index"):
                 # Load from a TensorFlow 1.X checkpoint - provided by original authors
@@ -1859,17 +1960,20 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 raise
         elif from_pt:
-            if low_cpu_mem_usage:
-                cls._load_pretrained_model_low_mem(model, loaded_state_dict_keys, resolved_archive_file)
-            else:
-                model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
-                    model,
-                    state_dict,
-                    resolved_archive_file,
-                    pretrained_model_name_or_path,
-                    ignore_mismatched_sizes=ignore_mismatched_sizes,
-                    sharded_metadata=sharded_metadata,
-                    _fast_init=_fast_init,
-                )
+
+            # restore default dtype
+            if dtype_orig is not None:
+                torch.set_default_dtype(dtype_orig)
+
+            model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
+                model,
+                state_dict,
+                loaded_state_dict_keys,  # XXX: rename?
+                resolved_archive_file,
+                pretrained_model_name_or_path,
+                ignore_mismatched_sizes=ignore_mismatched_sizes,
+                sharded_metadata=sharded_metadata,
+                _fast_init=_fast_init,
+                low_cpu_mem_usage=low_cpu_mem_usage,
+            )
 
         # make sure token embedding weights are still tied if needed
@@ -1894,16 +1998,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         cls,
         model,
         state_dict,
+        loaded_keys,
         resolved_archive_file,
         pretrained_model_name_or_path,
         ignore_mismatched_sizes=False,
         sharded_metadata=None,
         _fast_init=True,
+        low_cpu_mem_usage=False,
     ):
         # Retrieve missing & unexpected_keys
         model_state_dict = model.state_dict()
         expected_keys = list(model_state_dict.keys())
-        loaded_keys = list(state_dict.keys()) if state_dict is not None else sharded_metadata["all_checkpoint_keys"]
         prefix = model.base_model_prefix
 
         def _fix_key(key):
@@ -1994,9 +2099,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                             (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
                         )
                         del state_dict[checkpoint_key]
             return mismatched_keys
 
+        if low_cpu_mem_usage:
+            model_state_dict = None  # free references to model's params to allow memory freeing
+            _move_model_to_meta(model, loaded_keys, start_prefix)
+
         if state_dict is not None:
             # Whole checkpoint
             mismatched_keys = _find_mismatched_keys(
@@ -2009,7 +2117,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             )
             error_msgs = _load_state_dict_into_model(model_to_load, state_dict, start_prefix)
         else:
-            # Sharded checkpoint
+            # Sharded checkpoint or whole but low_cpu_mem_usage==True
+
             # This should always be a list but, just to be sure.
             if not isinstance(resolved_archive_file, list):
                 resolved_archive_file = [resolved_archive_file]
@@ -2018,6 +2127,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             mismatched_keys = []
             for shard_file in resolved_archive_file:
                 state_dict = load_state_dict(shard_file)
+
+                if low_cpu_mem_usage:
+                    model_state_dict = model.state_dict()
+
                 # Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
                 # matching the weights in the model.
                 mismatched_keys += _find_mismatched_keys(
@@ -2028,6 +2141,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                     remove_prefix_from_model,
                     ignore_mismatched_sizes,
                 )
-                error_msgs += _load_state_dict_into_model(model_to_load, state_dict, start_prefix)
+
+                if low_cpu_mem_usage:
+                    error_msgs += _load_state_dict_into_meta_model(
+                        model_to_load, state_dict, loaded_keys, start_prefix
+                    )
+                else:
+                    error_msgs += _load_state_dict_into_model(model_to_load, state_dict, start_prefix)
 
         if len(error_msgs) > 0:
@@ -2093,13 +2212,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         return retrieved_modules
 
     @staticmethod
-    def _load_pretrained_model_low_mem(model, loaded_state_dict_keys, resolved_archive_file):
+    def _load_pretrained_model_low_mem(model, loaded_state_dict_keys, resolved_archive_file, start_prefix=""):
         """
         This is an experimental function that loads the model using ~1.x model size CPU memory
 
-        Before it gets called we do:
+        Before you call it, do:
 
-        1. save which state_dict keys we have
+        1. save which state_dict keys are available
         2. drop state_dict before model is created, since the latter takes 1x model size memory
 
         Here then we continue:
@@ -2110,58 +2229,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         Currently, it doesn't handle missing_keys, unexpected_keys, mismatched_keys. It can't handle deepspeed.
         """
 
-        require_version_core("torch>=1.9")
-
-        if is_deepspeed_zero3_enabled():
-            raise ValueError("low_cpu_mem_usage arg cannot be used with DeepSpeed ZeRO-3")
-
-        # a helper util to find the last sub-module and the param/buffer name
-        def find_submodule_and_param_name(model, long_key):
-            split_key = long_key.split(".")
-            submodule = model
-            while len(split_key) > 1:
-                if hasattr(submodule, split_key[0]):
-                    submodule = getattr(submodule, split_key[0])
-                    del split_key[0]
-                else:
-                    submodule = None
-                    break
-            return submodule, split_key[0]
-
-        # dematerialize param storage for keys that are going to be replaced by state_dict, by
-        # putting those on the meta device
-        for k in loaded_state_dict_keys:
-            submodule, param_name = find_submodule_and_param_name(model, k)
-            if submodule is not None:
-                # selectively switch to the meta device only those params/buffers that will
-                # be next replaced from state_dict. This a complex way to do p.to_("meta")
-                # since we have no in-place to_ for tensors.
-                new_val = getattr(submodule, param_name)
-                if isinstance(new_val, torch.nn.Parameter):
-                    # isinstance returns False for Params on meta device, so switch after the check
-                    new_val = torch.nn.Parameter(new_val.to("meta"))
-                else:
-                    new_val = new_val.to("meta")
-                setattr(submodule, param_name, new_val)
-
-        # only now can load state_dict(s)
-        if not isinstance(resolved_archive_file, list):
-            resolved_archive_file = [resolved_archive_file]
-
-        for archive_file in resolved_archive_file:
-            state_dict = torch.load(archive_file, map_location="cpu")
-
-            # materialize state_dict entries one by one on CPU
-            for k in loaded_state_dict_keys:
-                if k in state_dict:
-                    submodule, param_name = find_submodule_and_param_name(model, k)
-                    if submodule is not None:
-                        param_dtype = getattr(submodule, param_name).dtype
-                        new_val = state_dict[k].to(param_dtype)
-                        if isinstance(getattr(submodule, param_name), torch.nn.Parameter):
-                            new_val = torch.nn.Parameter(new_val)
-                        setattr(submodule, param_name, new_val)
-
-            del state_dict
+        _move_model_to_meta(model, loaded_state_dict_keys, start_prefix)
+        state_dict = load_state_dict(resolved_archive_file)
+        error_msgs = _load_state_dict_into_meta_model(model, state_dict, loaded_state_dict_keys, start_prefix)
+        return error_msgs
 
     @classmethod
     def register_for_auto_class(cls, auto_class="AutoModel"):
......
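
For readers new to the meta-device trick that `_move_model_to_meta` and `_load_state_dict_into_meta_model` rely on, here is a small self-contained sketch of the same idea. It is illustrative only, not the library code: it assumes torch>=1.9 and a toy model whose state_dict keys are all parameters, whereas the real helpers also handle buffers, a `start_prefix`, and sharded checkpoints.

```python
import torch
import torch.nn as nn


def find_param_owner(model, key):
    # Walk the module tree: "0.weight" -> (model[0], "weight")
    parts = key.split(".")
    submodule = model
    for p in parts[:-1]:
        submodule = getattr(submodule, p)
    return submodule, parts[-1]


model = nn.Sequential(nn.Linear(4, 4))
state_dict = {k: v.clone() for k, v in model.state_dict().items()}

# 1. dematerialize: move the params that will be overwritten to the meta device,
#    freeing their CPU storage while keeping the module structure intact
for k in state_dict:
    owner, name = find_param_owner(model, k)
    setattr(owner, name, nn.Parameter(getattr(owner, name).to("meta")))

# 2. rematerialize: copy the checkpoint tensors back in, one by one, on CPU
for k, v in state_dict.items():
    owner, name = find_param_owner(model, k)
    setattr(owner, name, nn.Parameter(v.to(getattr(owner, name).dtype)))

print(next(model.parameters()).device)  # cpu again, weights restored from state_dict
```

The memory win comes from doing step 1 before the checkpoint is (re)loaded: at no point do a fully materialized randomly initialized model and a full checkpoint copy have to coexist in RAM.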
@@ -17,6 +17,7 @@ import inspect
 import logging
 import os
 import re
+import shlex
 import shutil
 import sys
 import tempfile
@@ -667,6 +668,20 @@ def require_librosa(test_case):
     return test_case
 
 
+def cmd_exists(cmd):
+    return shutil.which(cmd) is not None
+
+
+def require_usr_bin_time(test_case):
+    """
+    Decorator marking a test that requires `/usr/bin/time`
+    """
+    if not cmd_exists("/usr/bin/time"):
+        return unittest.skip("test requires /usr/bin/time")(test_case)
+    else:
+        return test_case
+
+
 def get_gpu_count():
     """
     Return the number of available gpus (regardless of whether torch, tf or jax is used)
@@ -1178,6 +1193,39 @@ class TestCasePlus(unittest.TestCase):
         return tmp_dir
 
+    def python_one_liner_max_rss(self, one_liner_str):
+        """
+        Runs the passed python one liner (just the code) and returns how much max cpu memory was used to run the
+        program.
+
+        Args:
+            one_liner_str (`string`):
+                a python one liner code that gets passed to `python -c`
+
+        Returns:
+            max cpu memory bytes used to run the program. This value is likely to vary slightly from run to run.
+
+        Requirements:
+            this helper needs `/usr/bin/time` to be installed (`apt install time`)
+
+        Example:
+
+        ```
+        one_liner_str = 'from transformers import AutoModel; AutoModel.from_pretrained("t5-large")'
+        max_rss = self.python_one_liner_max_rss(one_liner_str)
+        ```
+        """
+
+        if not cmd_exists("/usr/bin/time"):
+            raise ValueError("/usr/bin/time is required, install with `apt install time`")
+
+        cmd = shlex.split(f"/usr/bin/time -f %M python -c '{one_liner_str}'")
+        with CaptureStd() as cs:
+            execute_subprocess_async(cmd, env=self.get_env())
+        # returned data is in KB so convert to bytes
+        max_rss = int(cs.err.split("\n")[-2].replace("stderr: ", "")) * 1024
+        return max_rss
+
     def tearDown(self):
 
         # get_auto_remove_tmp_dir feature: remove registered temp dirs
......
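
The `python_one_liner_max_rss` helper added above leans on the test suite's own `execute_subprocess_async` and `CaptureStd` utilities. Stripped of those, the measurement idea reduces to the following sketch (assumptions: Linux with GNU time installed at `/usr/bin/time`, e.g. via `apt install time`; the function name here is illustrative, not part of the library):

```python
import subprocess


def one_liner_max_rss(one_liner_str: str) -> int:
    """Peak RSS, in bytes, of a fresh python process running the given one-liner."""
    # GNU time's -f %M prints the max resident set size in kilobytes on stderr
    cmd = ["/usr/bin/time", "-f", "%M", "python", "-c", one_liner_str]
    result = subprocess.run(cmd, capture_output=True, text=True, check=True)
    max_rss_kb = int(result.stderr.strip().splitlines()[-1])
    return max_rss_kb * 1024


if __name__ == "__main__":
    print(one_liner_max_rss("x = list(range(10_000_000))"))
```

Running the target code in a fresh subprocess matters: it keeps the measurement isolated from whatever the test process itself has already allocated.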
@@ -52,6 +52,7 @@ from transformers.testing_utils import (
     is_staging_test,
     require_torch,
     require_torch_multi_gpu,
+    require_usr_bin_time,
     slow,
     torch_device,
 )
@@ -2489,6 +2490,56 @@ class ModelUtilsTest(TestCasePlus):
         for p1, p2 in zip(model.parameters(), ref_model.parameters()):
             self.assertTrue(torch.allclose(p1, p2))
 
+    def test_from_pretrained_low_cpu_mem_usage_functional(self):
+        # test that we can use `from_pretrained(..., low_cpu_mem_usage=True)` with normal and
+        # sharded models
+        mnames = [
+            "hf-internal-testing/tiny-random-bert-sharded",
+            "hf-internal-testing/tiny-random-bert",
+        ]
+        for mname in mnames:
+            _ = BertModel.from_pretrained(mname, low_cpu_mem_usage=True)
+
+    @require_usr_bin_time
+    def test_from_pretrained_low_cpu_mem_usage_measured(self):
+        # test that `from_pretrained(..., low_cpu_mem_usage=True)` uses less cpu memory than default
+
+        mname = "bert-base-cased"
+
+        preamble = "from transformers import AutoModel"
+        one_liner_str = f'{preamble}; AutoModel.from_pretrained("{mname}", low_cpu_mem_usage=False)'
+        max_rss_normal = self.python_one_liner_max_rss(one_liner_str)
+        # print(f"{max_rss_normal=}")
+
+        one_liner_str = f'{preamble}; AutoModel.from_pretrained("{mname}", low_cpu_mem_usage=True)'
+        max_rss_low_mem = self.python_one_liner_max_rss(one_liner_str)
+        # print(f"{max_rss_low_mem=}")
+
+        diff_bytes = max_rss_normal - max_rss_low_mem
+        diff_percent = diff_bytes / max_rss_low_mem
+        # print(f"{diff_bytes=}, {diff_percent=}")
+
+        # ideally we would compare that the diff is close to ~1x checkpoint size in bytes, but
+        # measuring cpu memory on linux is very tricky and inconsistent, so instead let's check that
+        # it's at least 15% less cpu memory consumed
+        self.assertGreater(
+            diff_percent,
+            0.15,
+            "should use less CPU memory for low_cpu_mem_usage=True, "
+            f"but got max_rss_normal={max_rss_normal} and max_rss_low_mem={max_rss_low_mem}",
+        )
+
+        # if you want to compare things manually, let's first look at the size of the model in bytes
+        # model = BertModel.from_pretrained(mname, low_cpu_mem_usage=False)
+        # total_numel = sum(dict((p.data_ptr(), p.numel()) for p in model.parameters()).values())
+        # total_bytes = total_numel * 4  # 420MB
+        # Now the diff_bytes should be very close to total_bytes, but the reports are inconsistent.
+        # The easiest way to test this is to switch the model and torch.load to do all the work on
+        # gpu - that way one can measure exactly the total and peak memory used. Perhaps once we add
+        # functionality to load models directly on gpu, this test can be rewritten to use torch's
+        # cuda memory tracking and then we should be able to do a much more precise test.
+
     def test_cached_files_are_used_when_internet_is_down(self):
         # A mock response for an HTTP head request to emulate server down
         response_mock = mock.Mock()
......
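
To make the assertion in `test_from_pretrained_low_cpu_mem_usage_measured` concrete, here is the arithmetic with made-up numbers (purely illustrative; real RSS readings vary from run to run):

```python
# Hypothetical measurements: default load peaks at ~1.9 GB, low-mem load at ~1.4 GB.
max_rss_normal = 1_900_000_000
max_rss_low_mem = 1_400_000_000

diff_bytes = max_rss_normal - max_rss_low_mem  # 500_000_000
diff_percent = diff_bytes / max_rss_low_mem    # ~0.36

# The test only requires diff_percent > 0.15, i.e. at least 15% less peak CPU
# memory relative to the low-mem run, because absolute RSS numbers are noisy.
assert diff_percent > 0.15
```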