"examples/research_projects/colossalai/requirement.txt" did not exist on "c2283310688ff75e8fb4be3d9938ed0818cb038d"
Commit 03611f97 authored by Yoach Lacombe's avatar Yoach Lacombe
Browse files

optimize GPU memory usage

parent d112db94
...@@ -70,15 +70,13 @@ AutoModel.register(DACConfig, DACModel) ...@@ -70,15 +70,13 @@ AutoModel.register(DACConfig, DACModel)
from accelerate import Accelerator from accelerate import Accelerator
from accelerate.utils import set_seed from accelerate.utils import set_seed
from accelerate.utils.memory import release_memory
from stable_speech import StableSpeechForConditionalGeneration, StableSpeechConfig, apply_delay_pattern_mask, build_delay_pattern_mask from stable_speech import StableSpeechForConditionalGeneration, StableSpeechConfig, apply_delay_pattern_mask, build_delay_pattern_mask
if is_wandb_available(): if is_wandb_available():
from wandb import Audio from wandb import Audio
# Will error if the minimal version of Transformers is not installed. Remove at your own risks. # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.38.0.dev0") check_min_version("4.38.0.dev0")
...@@ -1122,11 +1120,13 @@ def main(): ...@@ -1122,11 +1120,13 @@ def main():
cosine_sim = torch.nn.functional.cosine_similarity(audio_features, text_features, dim=1, eps=1e-8) cosine_sim = torch.nn.functional.cosine_similarity(audio_features, text_features, dim=1, eps=1e-8)
clap.to("cpu")
clap_inputs.to("cpu")
return cosine_sim.mean() return cosine_sim.mean()
def wer(prompts, audios, device): def wer(prompts, audios, device):
asr_pipeline = pipeline(model="distil-whisper/distil-large-v2", device=device) asr_pipeline = pipeline(model="distil-whisper/distil-large-v2", device=device)
transcriptions = asr_pipeline([{'raw': audio, 'sampling_rate': sampling_rate} for audio in audios]) transcriptions = asr_pipeline([{'raw': audio, 'sampling_rate': sampling_rate} for audio in audios], batch_size=int(training_args.per_device_eval_batch_size))
word_error = 100 * metric.compute(predictions=[t["text"].lower() for t in transcriptions], references=[t.lower() for t in prompts]) word_error = 100 * metric.compute(predictions=[t["text"].lower() for t in transcriptions], references=[t.lower() for t in prompts])
...@@ -1418,6 +1418,9 @@ def main(): ...@@ -1418,6 +1418,9 @@ def main():
eval_descriptions = [] eval_descriptions = []
eval_prompts = [] eval_prompts = []
eval_start = time.time() eval_start = time.time()
# release training input batch
batch = release_memory(batch)
validation_dataloader = DataLoader( validation_dataloader = DataLoader(
vectorized_datasets["eval"], vectorized_datasets["eval"],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment