Unverified Commit 0b92ae34 authored by Yoach Lacombe, committed by GitHub

Add offload support to Bark (#25037)



* initial Bark offload proposal

* use hooks instead of manually offloading

* add test of bark offload to cpu feature

* Apply nit suggestions from code review
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update docstrings of offload
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* remove unnecessary set_seed in Bark tests

---------
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
parent 9cea3e7b
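
For context, here is a minimal usage sketch of the API this commit adds (not part of the diff below). It assumes `accelerate` is installed and a CUDA device is available; the `suno/bark-small` checkpoint and the prompt are illustrative placeholders.

    import torch
    from transformers import AutoProcessor, BarkModel

    processor = AutoProcessor.from_pretrained("suno/bark-small")  # illustrative checkpoint
    model = BarkModel.from_pretrained("suno/bark-small")

    # Attach the accelerate offload hooks added by this commit: sub-models stay on CPU
    # and are moved to cuda:0 one at a time while `generate` runs.
    model.enable_cpu_offload()

    inputs = processor("Hello, my dog is cute").to(model.device)
    with torch.no_grad():
        audio = model.generate(**inputs, do_sample=False)
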
@@ -23,8 +23,13 @@ from torch.nn import functional as F
 from ...generation.logits_process import AlternatingCodebooksLogitsProcessor, SuppressTokensLogitsProcessor
 from ...modeling_outputs import CausalLMOutputWithPast, MaskedLMOutput
-from ...modeling_utils import PreTrainedModel
-from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging
+from ...modeling_utils import PreTrainedModel, get_parameter_device
+from ...utils import (
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    is_accelerate_available,
+    logging,
+)
 from ..auto import AutoModel
 from .configuration_bark import (
     BarkCoarseConfig,
@@ -288,6 +293,26 @@ class BarkPreTrainedModel(PreTrainedModel):
     def __init__(self, *inputs, **kwargs):
         super().__init__(*inputs, **kwargs)
+
+    @property
+    def device(self) -> torch.device:
+        """
+        `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
+        device).
+        """
+        # if has _hf_hook, has been offloaded so the device has to be found in the hook
+        if not hasattr(self, "_hf_hook"):
+            return get_parameter_device(self)
+
+        for module in self.modules():
+            if (
+                hasattr(module, "_hf_hook")
+                and hasattr(module._hf_hook, "execution_device")
+                and module._hf_hook.execution_device is not None
+            ):
+                return torch.device(module._hf_hook.execution_device)
+
+        return get_parameter_device(self)
+
     def _set_gradient_checkpointing(self, module, value=False):
         if isinstance(module, BarkCausalModel) or isinstance(module, BarkFineModel) or isinstance(module, BarkModel):
             module.gradient_checkpointing = value
@@ -1376,6 +1401,63 @@ class BarkModel(BarkPreTrainedModel):
         self.config = config
+
+    @property
+    def device(self) -> torch.device:
+        """
+        `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
+        device).
+        """
+        # for bark_model, device must be verified on its sub-models
+        # if has _hf_hook, has been offloaded so the device has to be found in the hook
+        if not hasattr(self.semantic, "_hf_hook"):
+            return get_parameter_device(self)
+
+        for module in self.semantic.modules():
+            if (
+                hasattr(module, "_hf_hook")
+                and hasattr(module._hf_hook, "execution_device")
+                and module._hf_hook.execution_device is not None
+            ):
+                return torch.device(module._hf_hook.execution_device)
+
+    def enable_cpu_offload(self, gpu_id: Optional[int] = 0):
+        r"""
+        Offloads all sub-models to CPU using accelerate, reducing memory usage with a low impact on performance. This
+        method moves one whole sub-model at a time to the GPU when it is used, and the sub-model remains on the GPU
+        until the next sub-model runs.
+
+        Args:
+            gpu_id (`int`, *optional*, defaults to 0):
+                GPU id on which the sub-models will be loaded and offloaded.
+        """
+        if is_accelerate_available():
+            from accelerate import cpu_offload_with_hook
+        else:
+            raise ImportError("`enable_cpu_offload` requires `accelerate`.")
+
+        device = torch.device(f"cuda:{gpu_id}")
+
+        if self.device.type != "cpu":
+            self.to("cpu")
+            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+
+        # this layer is used outside the first forward pass of semantic, so it needs to be loaded before semantic
+        self.semantic.input_embeds_layer, _ = cpu_offload_with_hook(self.semantic.input_embeds_layer, device)
+
+        hook = None
+        for cpu_offloaded_model in [
+            self.semantic,
+            self.coarse_acoustics,
+            self.fine_acoustics,
+        ]:
+            _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+        self.fine_acoustics_hook = hook
+
+        _, hook = cpu_offload_with_hook(self.codec_model, device, prev_module_hook=hook)
+
+        # We'll offload the last model manually.
+        self.codec_model_hook = hook
+
     def codec_decode(self, fine_output):
         """Turn quantized audio codes into audio array using encodec."""
@@ -1490,9 +1572,20 @@ class BarkModel(BarkPreTrainedModel):
             **kwargs_fine,
         )
if getattr(self, "fine_acoustics_hook", None) is not None:
# Manually offload fine_acoustics to CPU
# and load codec_model to GPU
# since bark doesn't use codec_model forward pass
self.fine_acoustics_hook.offload()
self.codec_model = self.codec_model.to(self.device)
# 4. Decode the output and generate audio array # 4. Decode the output and generate audio array
audio = self.codec_decode(output) audio = self.codec_decode(output)
if getattr(self, "codec_model_hook", None) is not None:
# Offload codec_model to CPU
self.codec_model_hook.offload()
return audio return audio
     def can_generate(self) -> bool:
...
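
The offload design above chains accelerate's cpu_offload_with_hook: each call returns a hook, and passing that hook as prev_module_hook to the next call makes the previous sub-model move back to CPU as soon as the next one starts its forward pass. The last hooks in the chain (fine_acoustics_hook and codec_model_hook) therefore have to be offloaded manually in generate, as shown above. A standalone toy sketch of this pattern, assuming accelerate is installed and a CUDA device is available (the stage_* modules are illustrative):

    import torch
    from accelerate import cpu_offload_with_hook

    device = torch.device("cuda:0")
    stage_1 = torch.nn.Linear(8, 8)
    stage_2 = torch.nn.Linear(8, 8)

    # Each call returns a hook; passing it as prev_module_hook makes stage_1 move
    # back to CPU as soon as stage_2 starts running on the GPU.
    stage_1, hook_1 = cpu_offload_with_hook(stage_1, device)
    stage_2, hook_2 = cpu_offload_with_hook(stage_2, device, prev_module_hook=hook_1)

    x = torch.randn(1, 8, device=device)
    x = stage_1(x)  # stage_1 is moved onto the GPU here
    x = stage_2(x)  # stage_1 is offloaded back to CPU, stage_2 moves onto the GPU
    hook_2.offload()  # the last stage in the chain must be offloaded manually
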
@@ -31,7 +31,7 @@ from transformers.models.bark.generation_configuration_bark import (
     BarkFineGenerationConfig,
     BarkSemanticGenerationConfig,
 )
-from transformers.testing_utils import require_torch, slow, torch_device
+from transformers.testing_utils import require_torch, require_torch_gpu, slow, torch_device
 from transformers.utils import cached_property
 
 from ...generation.test_utils import GenerationTesterMixin
@@ -989,3 +989,42 @@ class BarkModelIntegrationTests(unittest.TestCase):
             coarse_temperature=0.2,
             fine_temperature=0.1,
         )
+
+    @require_torch_gpu
+    @slow
+    def test_generate_end_to_end_with_offload(self):
+        input_ids = self.inputs
+
+        with torch.no_grad():
+            # standard generation
+            output_with_no_offload = self.model.generate(**input_ids, do_sample=False, fine_temperature=None)
+
+            torch.cuda.empty_cache()
+
+            memory_before_offload = torch.cuda.memory_allocated()
+            model_memory_footprint = self.model.get_memory_footprint()
+
+            # activate cpu offload
+            self.model.enable_cpu_offload()
+
+            memory_after_offload = torch.cuda.memory_allocated()
+
+            # checks that the model has been offloaded: CUDA memory usage after offload
+            # should be near 0, leaving room for small differences
+            room_for_difference = 1.1
+            self.assertGreater(
+                (memory_before_offload - model_memory_footprint) * room_for_difference, memory_after_offload
+            )
+
+            # checks that the execution device is the correct one
+            self.assertEqual(self.model.device.type, torch_device)
+
+            # checks that the offload hooks exist
+            self.assertTrue(hasattr(self.model.semantic, "_hf_hook"))
+
+            # output with cpu offload
+            output_with_offload = self.model.generate(**input_ids, do_sample=False, fine_temperature=None)
+
+            # checks that the outputs are the same
+            self.assertListEqual(output_with_no_offload.tolist(), output_with_offload.tolist())