Commit 0112b0f0 authored by chenzk

v1.0
# set random seed, so that you may reproduce your result.
__set_seed1: !apply:random.seed [1024]
__set_seed2: !apply:numpy.random.seed [1024]
__set_seed3: !apply:torch.manual_seed [1024]
__set_seed4: !apply:torch.cuda.manual_seed_all [1024]
# fixed params
sample_rate: 24000
text_encoder_input_size: 512
llm_input_size: 1536
llm_output_size: 1536
basemodel_path: '../../pretrained_models/InspireMusic-1.5B-Long/'
generator_path: '../../pretrained_models/InspireMusic-1.5B-Long/music_tokenizer'
# model params
# For every class/function defined in this repo, we use !name: or !new: for initialization, so that users can trace each corresponding class/function from this single YAML file.
# For system/third-party classes/functions, this is not required.
llm: !new:inspiremusic.llm.llm.LLM
text_encoder_input_size: !ref <text_encoder_input_size>
llm_input_size: !ref <llm_input_size>
llm_output_size: !ref <llm_output_size>
audio_token_size: 4096
length_normalized_loss: True
lsm_weight: 0
text_encoder_conf:
name: "none"
llm: !new:inspiremusic.transformer.qwen_encoder.QwenEmbeddingEncoder
input_size: !ref <text_encoder_input_size>
pretrain_path: !ref <basemodel_path>
sampling: !name:inspiremusic.utils.common.topk_sampling
top_k: 350
train_cfg_ratio: 0.2
infer_cfg_ratio: 3.0
flow: !new:inspiremusic.flow.flow.MaskedDiff
input_size: 256
output_size: 80
output_type: 'mel'
vocab_size: 4096
input_frame_rate: 75
only_mask_loss: True
encoder: !new:inspiremusic.transformer.encoder.ConformerEncoder
output_size: 512
attention_heads: 4
linear_units: 1024
num_blocks: 3
dropout_rate: 0.1
positional_dropout_rate: 0.1
attention_dropout_rate: 0.1
normalize_before: True
input_layer: 'linear'
pos_enc_layer_type: 'rel_pos_espnet'
selfattention_layer_type: 'rel_selfattn'
input_size: 256
use_cnn_module: False
macaron_style: False
length_regulator: !new:inspiremusic.flow.length_regulator.InterpolateRegulator
channels: 512
sampling_ratios: [1, 1, 1, 1]
decoder: !new:inspiremusic.flow.flow_matching.ConditionalCFM
in_channels: 240
cfm_params: !new:omegaconf.DictConfig
content:
sigma_min: 1e-06
solver: 'euler'
t_scheduler: 'cosine'
training_cfg_rate: 0.2
inference_cfg_rate: 0.7
reg_loss_type: 'l1'
estimator: !new:inspiremusic.flow.decoder.ConditionalDecoder
in_channels: 1024
out_channels: 512
channels: [256, 256]
dropout: 0.0
attention_head_dim: 64
n_blocks: 4
num_mid_blocks: 8
num_heads: 8
act_fn: 'gelu'
generator_model_dir: !ref <generator_path>
hift: !new:inspiremusic.hifigan.generator.HiFTGenerator
in_channels: 80
base_channels: 512
nb_harmonics: 8
sampling_rate: !ref <sample_rate>
nsf_alpha: 0.1
nsf_sigma: 0.003
nsf_voiced_threshold: 10
upsample_rates: [8, 8]
upsample_kernel_sizes: [16, 16]
istft_params:
n_fft: 16
hop_len: 4
resblock_kernel_sizes: [3, 7, 11]
resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
source_resblock_kernel_sizes: [7, 11]
source_resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5]]
lrelu_slope: 0.1
audio_limit: 0.99
f0_predictor: !new:inspiremusic.hifigan.f0_predictor.ConvRNNF0Predictor
num_class: 1
in_channels: 80
cond_channels: 512
wavtokenizer: !new:inspiremusic.hifigan.generator.HiFTGenerator
# processor functions
parquet_opener: !name:inspiremusic.dataset.processor.parquet_opener
get_tokenizer: !name:inspiremusic.text.tokenizer.get_tokenizer
tokenizer_path: !ref <basemodel_path>
tokenizer_name: "qwen-2.5"
allowed_special: 'all'
tokenize: !name:inspiremusic.dataset.processor.tokenize
get_tokenizer: !ref <get_tokenizer>
allowed_special: !ref <allowed_special>
filter: !name:inspiremusic.dataset.processor.filter
max_length: 28000
min_length: 0
token_max_length: 200
token_min_length: 1
resample: !name:inspiremusic.dataset.processor.resample
resample_rate: !ref <sample_rate>
feat_extractor: !name:matcha.utils.audio.mel_spectrogram
n_fft: 1024
num_mels: 128
sampling_rate: !ref <sample_rate>
hop_size: 256
win_size: 1024
fmin: 0
fmax: 24000
center: False
compute_fbank: !name:inspiremusic.dataset.processor.compute_fbank
feat_extractor: !ref <feat_extractor>
parse_embedding: !name:inspiremusic.dataset.processor.parse_embedding
normalize: True
shuffle: !name:inspiremusic.dataset.processor.shuffle
shuffle_size: 1000
sort: !name:inspiremusic.dataset.processor.sort
sort_size: 500 # sort_size should be less than shuffle_size
batch: !name:inspiremusic.dataset.processor.batch
batch_type: 'dynamic'
max_frames_in_batch: 10000 # llm 12000
padding: !name:inspiremusic.dataset.processor.padding
# dataset processor pipeline
data_pipeline: [
!ref <parquet_opener>,
!ref <tokenize>,
!ref <shuffle>,
!ref <sort>,
!ref <filter>,
!ref <batch>,
!ref <padding>,
]
# train conf
train_conf:
optim: adam
optim_conf:
lr: 0.0001 # change to 0.001 if you want to train flow from scratch
scheduler: warmuplr
scheduler_conf:
warmup_steps: 5000
max_epoch: 200
grad_clip: 5
accum_grad: 2
log_interval: 100
save_per_step: 500
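The configuration above is consumed through HyperPyYAML: the inference and training scripts later in this commit load it with load_hyperpyyaml, and the !new:, !name: and !ref tags resolve into constructed modules, partial functions and shared values respectively. A minimal sketch of loading it, assuming the file is saved as conf/inspiremusic_1.5b_long.yaml (the path used by the 1.5B inference script below) and the referenced pretrained model directories exist:

```python
# Minimal sketch: load the HyperPyYAML config above and inspect what the tags produce.
# Assumes the config is saved as conf/inspiremusic_1.5b_long.yaml and the
# pretrained_models/ paths it references are present.
from hyperpyyaml import load_hyperpyyaml

with open("conf/inspiremusic_1.5b_long.yaml", "r") as f:
    configs = load_hyperpyyaml(f)

print(type(configs["llm"]))      # !new: entries are already-constructed modules (inspiremusic.llm.llm.LLM)
print(type(configs["flow"]))     # inspiremusic.flow.flow.MaskedDiff
print(configs["sample_rate"])    # 24000, shared through !ref <sample_rate>
print(configs["data_pipeline"])  # list of !name: partials applied by the Dataset
```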
# set random seed, so that you may reproduce your result.
__set_seed1: !apply:random.seed [1024]
__set_seed2: !apply:numpy.random.seed [1024]
__set_seed3: !apply:torch.manual_seed [1024]
__set_seed4: !apply:torch.cuda.manual_seed_all [1024]
# fixed params
sample_rate: 24000
target_sample_rate: 48000
text_encoder_input_size: 512
llm_input_size: 896
llm_output_size: 896
basemodel_path: '../../pretrained_models/InspireMusic-Base/'
generator_path: '../../pretrained_models/InspireMusic-Base/music_tokenizer'
# model params
# For every class/function defined in this repo, we use !name: or !new: for initialization, so that users can trace each corresponding class/function from this single YAML file.
# For system/third-party classes/functions, this is not required.
llm: !new:inspiremusic.llm.llm.LLM
text_encoder_input_size: !ref <text_encoder_input_size>
llm_input_size: !ref <llm_input_size>
llm_output_size: !ref <llm_output_size>
audio_token_size: 4096
length_normalized_loss: True
lsm_weight: 0
text_encoder_conf:
name: "none"
llm: !new:inspiremusic.transformer.qwen_encoder.QwenEmbeddingEncoder
input_size: !ref <text_encoder_input_size>
pretrain_path: !ref <basemodel_path>
sampling: !name:inspiremusic.utils.common.topk_sampling
top_k: 350
train_cfg_ratio: 0.2
infer_cfg_ratio: 3.0
flow: !new:inspiremusic.flow.flow.MaskedDiff
input_size: 256
output_size: 80
output_type: 'mel'
vocab_size: 4096
input_frame_rate: 75
only_mask_loss: True
encoder: !new:inspiremusic.transformer.encoder.ConformerEncoder
output_size: 512
attention_heads: 4
linear_units: 1024
num_blocks: 3
dropout_rate: 0.1
positional_dropout_rate: 0.1
attention_dropout_rate: 0.1
normalize_before: True
input_layer: 'linear'
pos_enc_layer_type: 'rel_pos_espnet'
selfattention_layer_type: 'rel_selfattn'
input_size: 256
use_cnn_module: False
macaron_style: False
length_regulator: !new:inspiremusic.flow.length_regulator.InterpolateRegulator
channels: 512
sampling_ratios: [1, 1, 1, 1]
decoder: !new:inspiremusic.flow.flow_matching.ConditionalCFM
in_channels: 240
cfm_params: !new:omegaconf.DictConfig
content:
sigma_min: 1e-06
solver: 'euler'
t_scheduler: 'cosine'
training_cfg_rate: 0.2
inference_cfg_rate: 0.7
reg_loss_type: 'l1'
estimator: !new:inspiremusic.flow.decoder.ConditionalDecoder
in_channels: 1024
out_channels: 512
channels: [256, 256]
dropout: 0.0
attention_head_dim: 64
n_blocks: 4
num_mid_blocks: 8
num_heads: 8
act_fn: 'gelu'
generator_model_dir: !ref <generator_path>
hift: !new:inspiremusic.hifigan.generator.HiFTGenerator
in_channels: 80
base_channels: 512
nb_harmonics: 8
sampling_rate: !ref <sample_rate>
nsf_alpha: 0.1
nsf_sigma: 0.003
nsf_voiced_threshold: 10
upsample_rates: [8, 8]
upsample_kernel_sizes: [16, 16]
istft_params:
n_fft: 16
hop_len: 4
resblock_kernel_sizes: [3, 7, 11]
resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
source_resblock_kernel_sizes: [7, 11]
source_resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5]]
lrelu_slope: 0.1
audio_limit: 0.99
f0_predictor: !new:inspiremusic.hifigan.f0_predictor.ConvRNNF0Predictor
num_class: 1
in_channels: 80
cond_channels: 512
wavtokenizer: !new:inspiremusic.hifigan.generator.HiFTGenerator
# processor functions
parquet_opener: !name:inspiremusic.dataset.processor.parquet_opener
get_tokenizer: !name:inspiremusic.text.tokenizer.get_tokenizer
tokenizer_path: !ref <basemodel_path>
tokenizer_name: "qwen-2.0"
allowed_special: 'all'
tokenize: !name:inspiremusic.dataset.processor.tokenize
get_tokenizer: !ref <get_tokenizer>
allowed_special: !ref <allowed_special>
filter: !name:inspiremusic.dataset.processor.filter
max_length: 20000
min_length: 1
token_max_length: 200
token_min_length: 1
max_acoustic_length: 20000
min_acoustic_length: 1800
mode: 'train_flow'
resample: !name:inspiremusic.dataset.processor.resample
resample_rate: !ref <sample_rate>
feat_extractor: !name:matcha.utils.audio.mel_spectrogram
n_fft: 1024
num_mels: 128
sampling_rate: !ref <sample_rate>
hop_size: 256
win_size: 1024
fmin: 0
fmax: 24000
center: False
compute_fbank: !name:inspiremusic.dataset.processor.compute_fbank
feat_extractor: !ref <feat_extractor>
parse_embedding: !name:inspiremusic.dataset.processor.parse_embedding
normalize: True
shuffle: !name:inspiremusic.dataset.processor.shuffle
shuffle_size: 1000
sort: !name:inspiremusic.dataset.processor.sort
sort_size: 500 # sort_size should be less than shuffle_size
batch: !name:inspiremusic.dataset.processor.batch
batch_type: 'dynamic'
max_frames_in_batch: 15500 # llm 12000
# batch_type: 'static'
# batch_size: 2 # llm 12000
padding: !name:inspiremusic.dataset.processor.padding
mode: 'train'
# dataset processor pipeline
data_pipeline: [
!ref <parquet_opener>,
!ref <tokenize>,
!ref <shuffle>,
!ref <sort>,
!ref <filter>,
!ref <batch>,
!ref <padding>,
]
# train conf
train_conf:
optim: adam
optim_conf:
lr: 0.0001 # change to 0.001 if you want to train flow from scratch
scheduler: warmuplr
scheduler_conf:
warmup_steps: 500
max_epoch: 200
grad_clip: 5
accum_grad: 2
log_interval: 100
save_per_step: 500
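train.py at the end of this commit loads this same kind of config with an overrides dict so that only the sub-model selected by --model is instantiated. A short sketch of that pattern, assuming the config above is saved as conf/inspiremusic.yaml (the path used by the Base inference script below):

```python
# Sketch of the override pattern from inspiremusic/bin/train.py: when training only
# the flow model, the llm and hift entries are overridden to None so those modules
# are never constructed.
from hyperpyyaml import load_hyperpyyaml

model_to_train = "flow"  # one of 'llm', 'flow', 'hift'
override_dict = {k: None for k in ["llm", "flow", "hift"] if k != model_to_train}

with open("conf/inspiremusic.yaml", "r") as f:
    configs = load_hyperpyyaml(f, overrides=override_dict)

model = configs[model_to_train]  # only this module was built; the others are None
```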
# set random seed, so that you may reproduce your result.
__set_seed1: !apply:random.seed [1024]
__set_seed2: !apply:numpy.random.seed [1024]
__set_seed3: !apply:torch.manual_seed [1024]
__set_seed4: !apply:torch.cuda.manual_seed_all [1024]
# fixed params
sample_rate: 24000
target_sample_rate: 48000
text_encoder_input_size: 512
llm_input_size: 896
llm_output_size: 896
basemodel_path: '../../pretrained_models/InspireMusic-Base-24kHz/'
generator_path: '../../pretrained_models/InspireMusic-Base-24kHz/music_tokenizer'
# model params
# For every class/function defined in this repo, we use !name: or !new: for initialization, so that users can trace each corresponding class/function from this single YAML file.
# For system/third-party classes/functions, this is not required.
llm: !new:inspiremusic.llm.llm.LLM
text_encoder_input_size: !ref <text_encoder_input_size>
llm_input_size: !ref <llm_input_size>
llm_output_size: !ref <llm_output_size>
audio_token_size: 4096
length_normalized_loss: True
lsm_weight: 0
text_encoder_conf:
name: "none"
llm: !new:inspiremusic.transformer.qwen_encoder.QwenEmbeddingEncoder
input_size: !ref <text_encoder_input_size>
pretrain_path: !ref <basemodel_path>
sampling: !name:inspiremusic.utils.common.topk_sampling
top_k: 350
train_cfg_ratio: 0.2
infer_cfg_ratio: 3.0
flow: !new:inspiremusic.flow.flow.MaskedDiff
input_size: 256
output_size: 80
output_type: 'mel'
vocab_size: 4096
input_frame_rate: 75
only_mask_loss: True
encoder: !new:inspiremusic.transformer.encoder.ConformerEncoder
output_size: 512
attention_heads: 4
linear_units: 1024
num_blocks: 3
dropout_rate: 0.1
positional_dropout_rate: 0.1
attention_dropout_rate: 0.1
normalize_before: True
input_layer: 'linear'
pos_enc_layer_type: 'rel_pos_espnet'
selfattention_layer_type: 'rel_selfattn'
input_size: 256
use_cnn_module: False
macaron_style: False
length_regulator: !new:inspiremusic.flow.length_regulator.InterpolateRegulator
channels: 512
sampling_ratios: [1, 1, 1, 1]
decoder: !new:inspiremusic.flow.flow_matching.ConditionalCFM
in_channels: 240
cfm_params: !new:omegaconf.DictConfig
content:
sigma_min: 1e-06
solver: 'euler'
t_scheduler: 'cosine'
training_cfg_rate: 0.2
inference_cfg_rate: 0.7
reg_loss_type: 'l1'
estimator: !new:inspiremusic.flow.decoder.ConditionalDecoder
in_channels: 1024
out_channels: 512
channels: [256, 256]
dropout: 0.0
attention_head_dim: 64
n_blocks: 4
num_mid_blocks: 8
num_heads: 8
act_fn: 'gelu'
generator_model_dir: !ref <generator_path>
hift: !new:inspiremusic.hifigan.generator.HiFTGenerator
in_channels: 80
base_channels: 512
nb_harmonics: 8
sampling_rate: !ref <sample_rate>
nsf_alpha: 0.1
nsf_sigma: 0.003
nsf_voiced_threshold: 10
upsample_rates: [8, 8]
upsample_kernel_sizes: [16, 16]
istft_params:
n_fft: 16
hop_len: 4
resblock_kernel_sizes: [3, 7, 11]
resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
source_resblock_kernel_sizes: [7, 11]
source_resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5]]
lrelu_slope: 0.1
audio_limit: 0.99
f0_predictor: !new:inspiremusic.hifigan.f0_predictor.ConvRNNF0Predictor
num_class: 1
in_channels: 80
cond_channels: 512
wavtokenizer: !new:inspiremusic.hifigan.generator.HiFTGenerator
# processor functions
parquet_opener: !name:inspiremusic.dataset.processor.parquet_opener
get_tokenizer: !name:inspiremusic.text.tokenizer.get_tokenizer
tokenizer_path: !ref <basemodel_path>
tokenizer_name: "qwen-2.0"
allowed_special: 'all'
tokenize: !name:inspiremusic.dataset.processor.tokenize
get_tokenizer: !ref <get_tokenizer>
allowed_special: !ref <allowed_special>
filter: !name:inspiremusic.dataset.processor.filter
max_length: 20000
min_length: 1
token_max_length: 200
token_min_length: 1
max_acoustic_length: 20000
min_acoustic_length: 1800
mode: 'train_flow'
resample: !name:inspiremusic.dataset.processor.resample
resample_rate: !ref <sample_rate>
feat_extractor: !name:matcha.utils.audio.mel_spectrogram
n_fft: 1024
num_mels: 128
sampling_rate: !ref <sample_rate>
hop_size: 256
win_size: 1024
fmin: 0
fmax: 24000
center: False
compute_fbank: !name:inspiremusic.dataset.processor.compute_fbank
feat_extractor: !ref <feat_extractor>
parse_embedding: !name:inspiremusic.dataset.processor.parse_embedding
normalize: True
shuffle: !name:inspiremusic.dataset.processor.shuffle
shuffle_size: 1000
sort: !name:inspiremusic.dataset.processor.sort
sort_size: 500 # sort_size should be less than shuffle_size
batch: !name:inspiremusic.dataset.processor.batch
batch_type: 'dynamic'
max_frames_in_batch: 15500 # llm 12000
# batch_type: 'static'
# batch_size: 2 # llm 12000
padding: !name:inspiremusic.dataset.processor.padding
mode: 'train'
# dataset processor pipeline
data_pipeline: [
!ref <parquet_opener>,
!ref <tokenize>,
!ref <shuffle>,
!ref <sort>,
!ref <filter>,
!ref <batch>,
!ref <padding>,
]
# train conf
train_conf:
optim: adam
optim_conf:
lr: 0.0001 # change to 0.001 if you want to train flow from scratch
scheduler: warmuplr
scheduler_conf:
warmup_steps: 500
max_epoch: 200
grad_clip: 5
accum_grad: 2
log_interval: 100
save_per_step: 500
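The !name: entries such as feat_extractor resolve to partially applied callables, with the keyword arguments above already bound. A hedged sketch, assuming this 24kHz config is saved as conf/inspiremusic_24khz.yaml (the filename is illustrative):

```python
# Hedged sketch: !name: entries become partial functions. feat_extractor resolves to
# matcha.utils.audio.mel_spectrogram with n_fft, num_mels, sampling_rate, hop_size,
# win_size, fmin, fmax and center already bound; only the waveform is passed at call
# time. The config filename is an assumption of this example.
import torch
from hyperpyyaml import load_hyperpyyaml

with open("conf/inspiremusic_24khz.yaml", "r") as f:
    configs = load_hyperpyyaml(f)

waveform = torch.randn(1, 24000)           # one second of audio at sample_rate
mel = configs["feat_extractor"](waveform)  # -> (1, num_mels, frames) mel spectrogram
```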
electro_1 <|90.00|><|chorus|><|A dynamic blend of electronic beats and drum and bass rhythms.|><|120.00|>
jazz_1 <|30.00|><|verse1|><|A smooth blend of contemporary jazz with soulful undertones, evoke a relaxed and sophisticated atmosphere.|><|60.00|>
instrumental_1 <|0.00|><|intro|><|A soothing piano instrumental with a melancholic feel, evoke a sense of longing, complemented by light and serene instrumental solos.|><|30.00|>
electro_1 dataset/example/electro_1.wav
jazz_1 dataset/example/jazz_1.wav
instrumental_1 dataset/example/instrumental_1.wav
data/samples/parquet/parquet_000000000.tar
1 <|30.0|><|verse|><|Experience soothing and sensual instrumental jazz with a touch of Bossa Nova, perfect for a relaxing restaurant or spa ambiance.|><|60.0|>
2 <|0.0|><|intro|><|A delightful collection of classical keyboard music, purely instrumental, exuding a timeless and elegant charm.|><|30.0|>
3 <|120.0|><|chorus|><|The instrumental rap track exudes a classic boom bap vibe, characterized by its French hip-hop roots and a smooth, rhythmic flow.|><|150.0|>
4 <|300.0|><|outro|><|The music exudes a vibrant and sophisticated jazz ambiance, characterized by the rich, dynamic sounds of a big band ensemble. With instrumental purity and a touch of classical influence, it offers a captivating listening experience.|><|330.0|>
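The sample files above pair an utterance key with either a prompt string of the form <|time_start|><|structure_label|><|caption|><|time_end|> or a wav path; inference.py below rebuilds exactly this prompt format as text_prompt. A small parsing sketch for one prompt line (the regex and variable names are illustrative, not utilities shipped with the repo):

```python
# Illustrative parser for the sample prompt lines above. The regex and field names
# are assumptions of this sketch, not part of the InspireMusic codebase.
import re

line = "jazz_1 <|30.00|><|verse1|><|A smooth blend of contemporary jazz with soulful undertones, evoke a relaxed and sophisticated atmosphere.|><|60.00|>"

key, prompt = line.split(" ", 1)
fields = re.findall(r"<\|(.*?)\|>", prompt)
time_start, structure, caption, time_end = float(fields[0]), fields[1], fields[2], float(fields[3])
print(key, time_start, structure, time_end)  # jazz_1 30.0 verse1 60.0
print(caption)
```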
#!/bin/bash
# Copyright 2024 Alibaba Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
. ./path.sh || exit 1;
export TOKENIZERS_PARALLELISM=False
model_name="InspireMusic-Base"
pretrained_model_dir=../../pretrained_models/${model_name}
dataset_name=samples
# inference normal mode
echo "Run inference."
expr_name="inspiremusic_${dataset_name}"
for task in 'text-to-music' 'continuation'; do
python inspiremusic/bin/inference.py --task $task \
--gpu 0 \
--config conf/inspiremusic.yaml \
--prompt_data data/${dataset_name}/parquet/data.list \
--flow_model $pretrained_model_dir/flow.pt \
--llm_model $pretrained_model_dir/llm.pt \
--music_tokenizer $pretrained_model_dir/music_tokenizer \
--wavtokenizer $pretrained_model_dir/wavtokenizer \
--chorus verse \
--output_sample_rate 48000 \
--min_generate_audio_seconds 5.0 \
--max_generate_audio_seconds 30.0 \
--result_dir `pwd`/exp/${model_name}/${task}_${expr_name}
# if use InspireMusic-xxxx-24kHz model, please set output sample rate to 24kHz
# --output_sample_rate 24000 \
# use fast inference mode
# --fast # fast mode without flow matching
echo `pwd`/exp/${model_name}/${task}_${expr_name}
done
#!/bin/bash
# Copyright 2024 Alibaba Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
. ./path.sh || exit 1;
export TOKENIZERS_PARALLELISM=False
model_name="InspireMusic-1.5B-Long"
pretrained_model_dir=../../pretrained_models/${model_name}
dataset_name=samples
# inference normal mode
echo "Run inference."
expr_name="inspiremusic_${dataset_name}"
for task in 'text-to-music' 'continuation'; do
python inspiremusic/bin/inference.py --task $task \
--gpu 0 \
--config conf/inspiremusic_1.5b_long.yaml \
--prompt_data data/${dataset_name}/parquet/data.list \
--flow_model $pretrained_model_dir/flow.pt \
--llm_model $pretrained_model_dir/llm.pt \
--music_tokenizer $pretrained_model_dir/music_tokenizer \
--wavtokenizer $pretrained_model_dir/wavtokenizer \
--chorus default \
--output_sample_rate 48000 \
--min_generate_audio_seconds 5.0 \
--max_generate_audio_seconds 300.0 \
--result_dir `pwd`/exp/${model_name}/${task}_${expr_name}
# if use InspireMusic-xxxx-24kHz model, please set output sample rate to 24kHz
# --output_sample_rate 24000 \
# use fast inference mode
# --fast # fast mode without flow matching
echo `pwd`/exp/${model_name}/${task}_${expr_name}
done
# Copyright (c) 2024 Alibaba Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import argparse
import logging
logging.getLogger('matplotlib').setLevel(logging.WARNING)
import os
import sys
import torch
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append('{}/../..'.format(ROOT_DIR))
sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
from inspiremusic.cli.inspiremusic import InspireMusic
def get_args():
parser = argparse.ArgumentParser(description='export your model for deployment')
parser.add_argument('--model_dir',
type=str,
default='pretrained_models/InspireMusic',
help='local path')
args = parser.parse_args()
print(args)
return args
def main():
args = get_args()
logging.basicConfig(level=logging.DEBUG,
format='%(asctime)s %(levelname)s %(message)s')
torch._C._jit_set_fusion_strategy([('STATIC', 1)])
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
inspiremusic = InspireMusic(args.model_dir, load_jit=False, load_onnx=False)
# 1. export llm text_encoder
llm_text_encoder = inspiremusic.model.llm.text_encoder.half()
script = torch.jit.script(llm_text_encoder)
script = torch.jit.freeze(script)
script = torch.jit.optimize_for_inference(script)
script.save('{}/llm.text_encoder.fp16.zip'.format(args.model_dir))
# 2. export llm llm
llm_llm = inspiremusic.model.llm.llm.half()
script = torch.jit.script(llm_llm)
script = torch.jit.freeze(script, preserved_attrs=['forward_chunk'])
script = torch.jit.optimize_for_inference(script)
script.save('{}/llm.llm.fp16.zip'.format(args.model_dir))
# 3. export flow encoder
flow_encoder = inspiremusic.model.flow.encoder
script = torch.jit.script(flow_encoder)
script = torch.jit.freeze(script)
script = torch.jit.optimize_for_inference(script)
script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
if __name__ == '__main__':
main()
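export_jit.py writes three TorchScript archives next to the checkpoints. As a usage sketch, they can be reloaded with torch.jit.load; the file names mirror the script.save() calls above, while how the runtime actually consumes them is an assumption of this example:

```python
# Sketch: reload the TorchScript modules written by export_jit.py.
import torch

model_dir = "pretrained_models/InspireMusic"  # same default as --model_dir above
device = "cuda" if torch.cuda.is_available() else "cpu"

llm_text_encoder = torch.jit.load(f"{model_dir}/llm.text_encoder.fp16.zip", map_location=device)
llm_llm = torch.jit.load(f"{model_dir}/llm.llm.fp16.zip", map_location=device)
flow_encoder = torch.jit.load(f"{model_dir}/flow.encoder.fp32.zip", map_location=device)

for m in (llm_text_encoder, llm_llm, flow_encoder):
    m.eval()
```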
# Copyright (c) 2024 Antgroup Inc (authors: Zhoubofan, hexisyztem@icloud.com)
# Copyright (c) 2024 Alibaba Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import argparse
import logging
logging.getLogger('matplotlib').setLevel(logging.WARNING)
import os
import sys
import onnxruntime
import random
import torch
from tqdm import tqdm
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append('{}/../..'.format(ROOT_DIR))
sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
from inspiremusic.cli.inspiremusic import InspireMusic
def get_dummy_input(batch_size, seq_len, out_channels, device):
x = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
mask = torch.ones((batch_size, 1, seq_len), dtype=torch.float32, device=device)
mu = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
t = torch.rand((batch_size), dtype=torch.float32, device=device)
spks = torch.rand((batch_size, out_channels), dtype=torch.float32, device=device)
cond = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
return x, mask, mu, t, spks, cond
def get_args():
parser = argparse.ArgumentParser(description='export your model for deployment')
parser.add_argument('--model_dir',
type=str,
default='pretrained_models/InspireMusic',
help='local path')
args = parser.parse_args()
print(args)
return args
def main():
args = get_args()
logging.basicConfig(level=logging.DEBUG,
format='%(asctime)s %(levelname)s %(message)s')
inspiremusic = InspireMusic(args.model_dir, load_jit=False, load_onnx=False)
# 1. export flow decoder estimator
estimator = inspiremusic.model.flow.decoder.estimator
device = inspiremusic.model.device
batch_size, seq_len = 1, 256
out_channels = inspiremusic.model.flow.decoder.estimator.out_channels
x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device)
torch.onnx.export(
estimator,
(x, mask, mu, t, spks, cond),
'{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
export_params=True,
opset_version=18,
do_constant_folding=True,
input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'],
output_names=['estimator_out'],
dynamic_axes={
'x': {0: 'batch_size', 2: 'seq_len'},
'mask': {0: 'batch_size', 2: 'seq_len'},
'mu': {0: 'batch_size', 2: 'seq_len'},
'cond': {0: 'batch_size', 2: 'seq_len'},
't': {0: 'batch_size'},
'spks': {0: 'batch_size'},
'estimator_out': {0: 'batch_size', 2: 'seq_len'},
}
)
# 2. test computation consistency
option = onnxruntime.SessionOptions()
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
option.intra_op_num_threads = 1
providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
sess_options=option, providers=providers)
for _ in tqdm(range(10)):
x, mask, mu, t, spks, cond = get_dummy_input(random.randint(1, 6), random.randint(16, 512), out_channels, device)
output_pytorch = estimator(x, mask, mu, t, spks, cond)
ort_inputs = {
'x': x.cpu().numpy(),
'mask': mask.cpu().numpy(),
'mu': mu.cpu().numpy(),
't': t.cpu().numpy(),
'spks': spks.cpu().numpy(),
'cond': cond.cpu().numpy()
}
output_onnx = estimator_onnx.run(None, ort_inputs)[0]
torch.testing.assert_close(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4)
if __name__ == "__main__":
main()
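Beyond the consistency check above, the exported estimator can be wrapped so that it keeps the positional call signature (x, mask, mu, t, spks, cond) of the PyTorch module. The class below is a hedged sketch; it is not an interface defined by the repo:

```python
# Hedged sketch: expose the ONNX estimator behind the same positional signature as
# the PyTorch module. The class name and this drop-in usage are assumptions of the
# example, not repo APIs.
import onnxruntime
import torch


class OnnxEstimator:
    def __init__(self, onnx_path: str, use_cuda: bool = True):
        providers = ["CUDAExecutionProvider"] if use_cuda else ["CPUExecutionProvider"]
        self.session = onnxruntime.InferenceSession(onnx_path, providers=providers)

    def __call__(self, x, mask, mu, t, spks, cond):
        ort_inputs = {
            "x": x.cpu().numpy(), "mask": mask.cpu().numpy(), "mu": mu.cpu().numpy(),
            "t": t.cpu().numpy(), "spks": spks.cpu().numpy(), "cond": cond.cpu().numpy(),
        }
        out = self.session.run(None, ort_inputs)[0]
        return torch.from_numpy(out).to(x.device)
```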
# Copyright (c) 2024 Alibaba Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import argparse
import logging
logging.getLogger('matplotlib').setLevel(logging.WARNING)
import os
import torch
from torch.utils.data import DataLoader
import torchaudio
from hyperpyyaml import load_hyperpyyaml
from tqdm import tqdm
from inspiremusic.cli.model import InspireMusicModel
from inspiremusic.dataset.dataset import Dataset
from inspiremusic.utils.common import MUSIC_STRUCTURE_LABELS
def get_args():
parser = argparse.ArgumentParser(description='inference only with flow model')
parser.add_argument('--config', required=True, help='config file')
parser.add_argument('--prompt_data', required=True, help='prompt data file')
parser.add_argument('--flow_model', required=True, help='flow model file')
parser.add_argument('--llm_model', default=None,required=False, help='llm model file')
parser.add_argument('--music_tokenizer', required=True, help='music tokenizer model file')
parser.add_argument('--wavtokenizer', required=True, help='wavtokenizer model file')
parser.add_argument('--chorus', default="random", required=False, help='chorus tag generation mode, e.g., random, verse, chorus, intro, outro.')
parser.add_argument('--sample_rate', type=int, default=48000, required=False,
help='sampling rate of generated audio')
parser.add_argument('--min_generate_audio_seconds', type=float, default=10.0, required=False,
help='the minimum generated audio length in seconds')
parser.add_argument('--max_generate_audio_seconds', type=float, default=30.0, required=False,
help='the maximum generated audio length in seconds')
parser.add_argument('--gpu',
type=int,
default=-1,
help='gpu id for this rank, -1 for cpu')
parser.add_argument('--result_dir', required=True, help='directory to save generated audio and the wav.scp index')
args = parser.parse_args()
print(args)
return args
def main():
args = get_args()
logging.basicConfig(level=logging.DEBUG,
format='%(asctime)s %(levelname)s %(message)s')
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
# Init inspiremusic models from configs
use_cuda = args.gpu >= 0 and torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
with open(args.config, 'r') as f:
configs = load_hyperpyyaml(f)
model = InspireMusicModel(None, configs['flow'], configs['hift'], configs['wavtokenizer'])
model.load(args.llm_model, args.flow_model, args.music_tokenizer, args.wavtokenizer)
if args.llm_model is None:
model.llm = None
else:
model.llm = model.llm.to(torch.float32)
if args.flow_model is None:
model.flow = None
test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=True, partition=False)
test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)
del configs
os.makedirs(args.result_dir, exist_ok=True)
fn = os.path.join(args.result_dir, 'wav.scp')
f = open(fn, 'w')
with torch.no_grad():
for _, batch in tqdm(enumerate(test_data_loader)):
utts = batch["utts"]
assert len(utts) == 1, "inference mode only support batchsize 1"
if "semantic_token" in batch:
token = batch["semantic_token"].to(device)
token_len = batch["semantic_token_len"].to(device)
else:
if audio_token is None:
token = None
token_len = None
else:
token = audio_token.view(audio_token.size(0),-1,4)[:,:,0]
token_len = audio_token_len / 4
text_token = batch["text_token"].to(device)
text_token_len = batch["text_token_len"].to(device)
text = batch["text"]
if "time_start" not in batch.keys():
batch["time_start"] = torch.randint(0, args.min_generate_audio_seconds, (1,)).to(torch.float64)
if "time_end" not in batch.keys():
batch["time_end"] = torch.randint(args.min_generate_audio_seconds, args.max_generate_audio_seconds, (1,)).to(torch.float64)
elif (batch["time_end"].numpy()[0] - batch["time_start"].numpy()[0]) < args.min_generate_audio_seconds:
batch["time_end"] = torch.randint(int(batch["time_start"].numpy()[0] + args.min_generate_audio_seconds), int(batch["time_start"].numpy()[0] + args.max_generate_audio_seconds), (1,)).to(torch.float64)
if "chorus" not in batch.keys():
batch["chorus"] = torch.randint(1, 5, (1,))
if args.chorus == "random":
batch["chorus"] = torch.randint(1, 5, (1,))
elif args.chorus == "intro":
batch["chorus"] = torch.Tensor([0])
elif "verse" in args.chorus:
batch["chorus"] = torch.Tensor([1])
elif args.chorus == "chorus":
batch["chorus"] = torch.Tensor([2])
elif args.chorus == "outro":
batch["chorus"] = torch.Tensor([4])
time_start = batch["time_start"].to(device)
time_end = batch["time_end"].to(device)
chorus = batch["chorus"].to(torch.int)
text_prompt = f"<|{batch['time_start'].numpy()[0]}|><|{MUSIC_STRUCTURE_LABELS[chorus.numpy()[0]]}|><|{batch['text'][0]}|><|{batch['time_end'].numpy()[0]}|>"
chorus = chorus.to(device)
model_input = {"text": text, "audio_token": token, "audio_token_len": token_len,
"text_token": text_token, "text_token_len": text_token_len,
"embeddings": [time_start, time_end, chorus], "raw_text":text}
music_audios = []
for model_output in model.inference(**model_input):
music_audios.append(model_output['music_audio'])
music_key = utts[0]
music_fn = os.path.join(args.result_dir, '{}.wav'.format(music_key))
torchaudio.save(music_fn, music_audios[0], sample_rate=args.sample_rate)
f.write('{} {}\n'.format(music_key, music_fn))
f.flush()
f.close()
logging.info('Result wav.scp saved in {}'.format(fn))
if __name__ == '__main__':
main()
# Copyright (c) 2024 Alibaba Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import argparse
import logging
logging.getLogger('matplotlib').setLevel(logging.WARNING)
import os
import torch
from torch.utils.data import DataLoader
import torchaudio
from hyperpyyaml import load_hyperpyyaml
from tqdm import tqdm
from inspiremusic.cli.model import InspireMusicModel
from inspiremusic.dataset.dataset import Dataset
import time
from inspiremusic.utils.audio_utils import trim_audio, fade_out, process_audio
from inspiremusic.utils.common import MUSIC_STRUCTURE_LABELS
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
def get_args():
parser = argparse.ArgumentParser(description='inference only with your model')
parser.add_argument('--config', required=True, help='config file')
parser.add_argument('--prompt_data', required=True, help='prompt data file')
parser.add_argument('--flow_model', default=None, required=False, help='flow model file')
parser.add_argument('--llm_model', default=None, required=False, help='llm model file')
parser.add_argument('--music_tokenizer', required=True, help='music tokenizer model file')
parser.add_argument('--wavtokenizer', required=True, help='wavtokenizer model file')
parser.add_argument('--chorus', default="random", required=False, help='chorus tag generation mode, e.g., random, verse, chorus, intro, outro.')
parser.add_argument('--fast', action='store_true', required=False, help='True: fast inference mode, without flow matching for fast inference. False: normal inference mode, with flow matching for high quality.')
parser.add_argument('--fp16', default=True, type=bool, required=False, help='inference with fp16 model')
parser.add_argument('--fade_out', default=True, type=bool, required=False, help='add fade out effect to generated audio')
parser.add_argument('--fade_out_duration', default=1.0, type=float, required=False, help='fade out duration in seconds')
parser.add_argument('--trim', default=False, type=bool, required=False, help='trim the silence ending of generated audio')
parser.add_argument('--format', type=str, default="wav", required=False,
choices=["wav", "mp3", "m4a", "flac"],
help='format of the generated output audio')
parser.add_argument('--sample_rate', type=int, default=24000, required=False,
help='sampling rate of input audio')
parser.add_argument('--output_sample_rate', type=int, default=48000, required=False, choices=[24000, 48000],
help='sampling rate of generated output audio')
parser.add_argument('--min_generate_audio_seconds', type=float, default=10.0, required=False,
help='the minimum generated audio length in seconds')
parser.add_argument('--max_generate_audio_seconds', type=float, default=30.0, required=False,
help='the maximum generated audio length in seconds')
parser.add_argument('--gpu',
type=int,
default=0,
help='gpu id for this rank, -1 for cpu')
parser.add_argument('--task',
default='text-to-music',
choices=['text-to-music', 'continuation', "reconstruct", "super_resolution"],
help='choose inference task type. text-to-music: text-to-music task. continuation: music continuation task. reconstruct: reconstruction of original music. super_resolution: convert original 24kHz music into 48kHz music.')
parser.add_argument('--result_dir', required=True, help='directory to save generated audio, wav.scp and captions.txt')
args = parser.parse_args()
print(args)
return args
def main():
args = get_args()
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s %(levelname)s %(message)s')
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
if args.fast:
args.output_sample_rate = 24000
min_generate_audio_length = int(args.output_sample_rate * args.min_generate_audio_seconds)
max_generate_audio_length = int(args.output_sample_rate * args.max_generate_audio_seconds)
assert args.min_generate_audio_seconds <= args.max_generate_audio_seconds
# Init inspiremusic models from configs
use_cuda = args.gpu >= 0 and torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
with open(args.config, 'r') as f:
configs = load_hyperpyyaml(f)
model = InspireMusicModel(configs['llm'], configs['flow'], configs['hift'], configs['wavtokenizer'], args.fast, args.fp16)
model.load(args.llm_model, args.flow_model, args.music_tokenizer, args.wavtokenizer)
if args.llm_model is None:
model.llm = None
else:
model.llm = model.llm.to(torch.float32)
if args.flow_model is None:
model.flow = None
test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=True, partition=False)
test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)
del configs
os.makedirs(args.result_dir, exist_ok=True)
fn = os.path.join(args.result_dir, 'wav.scp')
f = open(fn, 'w')
caption_fn = os.path.join(args.result_dir, 'captions.txt')
caption_f = open(caption_fn, 'w')
with torch.no_grad():
for _, batch in tqdm(enumerate(test_data_loader)):
utts = batch["utts"]
assert len(utts) == 1, "inference mode only support batchsize 1"
text_token = batch["text_token"].to(device)
text_token_len = batch["text_token_len"].to(device)
if "time_start" not in batch.keys():
batch["time_start"] = torch.randint(0, args.min_generate_audio_seconds, (1,)).to(torch.float64)
if batch["time_start"].numpy()[0] > 300:
batch["time_start"] = torch.Tensor([0]).to(torch.float64)
if "time_end" not in batch.keys():
batch["time_end"] = torch.randint(int(batch["time_start"].numpy()[0] + args.min_generate_audio_seconds), int(batch["time_start"].numpy()[0] + args.max_generate_audio_seconds), (1,)).to(torch.float64)
else:
if (batch["time_end"].numpy()[0] - batch["time_start"].numpy()[0]) < args.min_generate_audio_seconds:
batch["time_end"] = torch.randint(int(batch["time_start"].numpy()[0] + args.min_generate_audio_seconds), int(batch["time_start"].numpy()[0] + args.max_generate_audio_seconds), (1,)).to(torch.float64)
elif (batch["time_end"].numpy()[0] - batch["time_start"].numpy()[0]) > args.max_generate_audio_seconds:
batch["time_end"] = torch.Tensor([(batch["time_start"].numpy()[0] + args.max_generate_audio_seconds)]).to(torch.float64)
if "chorus" not in batch.keys():
batch["chorus"] = torch.randint(1, 5, (1,))
if args.chorus == "random":
batch["chorus"] = torch.randint(1, 5, (1,))
elif args.chorus == "intro":
batch["chorus"] = torch.Tensor([0])
elif "verse" in args.chorus:
batch["chorus"] = torch.Tensor([1])
elif args.chorus == "chorus":
batch["chorus"] = torch.Tensor([2])
elif args.chorus == "outro":
batch["chorus"] = torch.Tensor([4])
else:
batch["chorus"] = batch["chorus"]
time_start = batch["time_start"].to(device)
time_end = batch["time_end"].to(device)
chorus = batch["chorus"].to(torch.int)
text_prompt = f"<|{batch['time_start'].numpy()[0]}|><|{MUSIC_STRUCTURE_LABELS[chorus.numpy()[0]]}|><|{batch['text'][0]}|><|{batch['time_end'].numpy()[0]}|>"
chorus = chorus.to(device)
if batch["acoustic_token"] is None:
audio_token = None
audio_token_len = None
else:
audio_token = batch["acoustic_token"].to(device)
audio_token_len = batch["acoustic_token_len"].to(device)
text = batch["text"]
if "semantic_token" in batch:
token = batch["semantic_token"].to(device)
token_len = batch["semantic_token_len"].to(device)
else:
if audio_token is None:
token = None
token_len = None
else:
token = audio_token.view(audio_token.size(0), -1, 4)[:, :, 0]
token_len = audio_token_len / 4
if args.task in ['text-to-music', 'continuation']:
# text to music, music continuation
model_input = {"text": text, "audio_token": token,
"audio_token_len": token_len,
"text_token": text_token,
"text_token_len": text_token_len,
"embeddings": [time_start, time_end, chorus],
"raw_text": text,
"sample_rate": args.output_sample_rate,
"duration_to_gen": args.max_generate_audio_seconds,
"task": args.task}
elif args.task in ['reconstruct', 'super_resolution']:
# audio reconstruction, audio super resolution
model_input = {"text": text, "audio_token": audio_token,
"audio_token_len": audio_token_len,
"text_token": text_token,
"text_token_len": text_token_len,
"embeddings": [time_start, time_end, chorus],
"raw_text": text,
"sample_rate": args.output_sample_rate,
"duration_to_gen": args.max_generate_audio_seconds,
"task": args.task}
else:
# zero-shot
model_input = {'text' : text,
'text_len' : text_token_len,
'prompt_text' : text_token,
'prompt_text_len' : text_token_len,
'llm_prompt_audio_token' : token,
'llm_prompt_audio_token_len' : token_len,
'flow_prompt_audio_token' : audio_token,
'flow_prompt_audio_token_len': audio_token_len,
'prompt_audio_feat' : audio_feat,
'prompt_audio_feat_len' : audio_feat_len,
"embeddings" : [time_start,
time_end,
chorus]}
music_key = utts[0]
music_audios = []
music_fn = os.path.join(args.result_dir, f'{music_key}.{args.format}')
bench_start = time.time()
for model_output in model.inference(**model_input):
music_audios.append(model_output['music_audio'])
bench_end = time.time()
if args.trim:
music_audio = trim_audio(music_audios[0],
sample_rate=args.output_sample_rate,
threshold=0.05,
min_silence_duration=0.8)
else:
music_audio = music_audios[0]
if music_audio.shape[0] != 0:
if music_audio.shape[1] > max_generate_audio_length:
music_audio = music_audio[:, :max_generate_audio_length]
if music_audio.shape[1] >= min_generate_audio_length:
try:
if args.fade_out:
music_audio = fade_out(music_audio, args.output_sample_rate, args.fade_out_duration)
music_audio = music_audio.repeat(2, 1)
if args.format in ["wav", "flac"]:
torchaudio.save(music_fn, music_audio, sample_rate=args.output_sample_rate, encoding="PCM_S", bits_per_sample=24)
elif args.format in ["mp3", "m4a"]:
torchaudio.backend.sox_io_backend.save(filepath=music_fn, src=music_audio, sample_rate=args.output_sample_rate, format=args.format)
else:
logging.info(f"Format is not supported. Please choose from wav, mp3, m4a, flac.")
except Exception as e:
logging.info(f"Error saving file: {e}")
raise
audio_duration = music_audio.shape[1] / args.output_sample_rate
rtf = (bench_end - bench_start) / audio_duration
logging.info(f"processing time: {int(bench_end - bench_start)}s, audio length: {int(audio_duration)}s, rtf: {rtf}, text prompt: {text_prompt}")
f.write('{} {}\n'.format(music_key, music_fn))
f.flush()
caption_f.write('{}\t{}\n'.format(music_key, text_prompt))
caption_f.flush()
else:
logging.info(f"Generate audio length {music_audio.shape[1]} is shorter than min_generate_audio_length.")
else:
logging.info(f"Generate audio is empty, dim = {music_audio.shape[0]}.")
f.close()
caption_f.close()
logging.info('Result wav.scp saved in {}'.format(fn))
if __name__ == '__main__':
main()
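flow_only_infer.py and inference.py duplicate the mapping from the --chorus argument to a structure-label index (intro=0, verse=1, chorus=2, outro=4, with 'random' drawing an index from 1 to 4). A hedged refactoring sketch of that shared logic; the helper name is not part of the repo:

```python
# Hedged sketch: consolidate the chorus-tag selection duplicated in the two
# inference scripts above. The function name is an assumption; the index mapping
# mirrors their elif chains.
import torch


def chorus_to_index(chorus_arg: str) -> torch.Tensor:
    if chorus_arg == "intro":
        return torch.Tensor([0])
    if "verse" in chorus_arg:
        return torch.Tensor([1])
    if chorus_arg == "chorus":
        return torch.Tensor([2])
    if chorus_arg == "outro":
        return torch.Tensor([4])
    # 'random' (and any unrecognized value) falls back to a random structure tag
    return torch.randint(1, 5, (1,))
```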
# Copyright (c) 2024 Alibaba Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import argparse
import datetime
import logging
logging.getLogger('matplotlib').setLevel(logging.WARNING)
from copy import deepcopy
import torch
import torch.distributed as dist
import deepspeed
import glob
import os
from hyperpyyaml import load_hyperpyyaml
from torch.cuda.amp import GradScaler, autocast
from torch.distributed.elastic.multiprocessing.errors import record
from peft import get_peft_config, get_peft_model, LoraConfig, TaskType
from inspiremusic.utils.executor import Executor
from inspiremusic.utils.train_utils import (
init_distributed,
init_dataset_and_dataloader,
init_optimizer_and_scheduler,
init_summarywriter, save_model,
wrap_cuda_model, check_modify_and_save_config)
def get_args():
parser = argparse.ArgumentParser(description='training your network')
parser.add_argument('--train_engine',
default='torch_ddp',
choices=['torch_ddp', 'deepspeed'],
help='Engine for paralleled training')
parser.add_argument('--model', required=True, help='model which will be trained')
parser.add_argument('--config', required=True, help='config file')
parser.add_argument('--train_data', required=True, help='train data file')
parser.add_argument('--cv_data', required=True, help='cv data file')
parser.add_argument('--checkpoint', help='checkpoint model')
parser.add_argument('--model_dir', required=True, help='save model dir')
parser.add_argument('--tensorboard_dir',
default='tensorboard',
help='tensorboard log dir')
parser.add_argument('--ddp.dist_backend',
dest='dist_backend',
default='nccl',
choices=['nccl', 'gloo'],
help='distributed backend')
parser.add_argument('--num_workers',
default=0,
type=int,
help='number of subprocess workers for reading')
parser.add_argument('--prefetch',
default=100,
type=int,
help='prefetch number')
parser.add_argument('--pin_memory',
action='store_true',
default=True,
help='Use pinned memory buffers used for reading')
parser.add_argument('--deepspeed.save_states',
dest='save_states',
default='model_only',
choices=['model_only', 'model+optimizer'],
help='save model/optimizer states')
parser.add_argument('--timeout',
default=30,
type=int,
help='timeout (in seconds) of inspiremusic_join.')
parser.add_argument('--fp16',
action='store_true',
default=False,
help='Enable fp16 mixed precision training')
parser.add_argument('--lora',
action='store_true',
default=False,
help='Enable LoRA training')
parser.add_argument('--lora_rank',
default=4,
type=int,
help='LoRA rank')
parser.add_argument('--lora_alpha',
default=16,
type=int,
help='LoRA alpha')
parser.add_argument('--lora_dropout',
default=0.1,
type=float,
help='LoRA dropout rate')
parser.add_argument('--lora_target_modules',
nargs='+',
default=["k_proj","v_proj"],
help='Target modules to apply LoRA (e.g., ["q_proj", "v_proj"])')
parser = deepspeed.add_config_arguments(parser)
args = parser.parse_args()
return args
@record
def main():
args = get_args()
logging.basicConfig(level=logging.DEBUG,
format='%(asctime)s %(levelname)s %(message)s')
override_dict = {k: None for k in ['llm', 'flow', 'hift'] if k != args.model}
with open(args.config, 'r') as f:
configs = load_hyperpyyaml(f, overrides=override_dict)
configs['train_conf'].update(vars(args))
# Init env for ddp
init_distributed(args)
# Get dataset & dataloader
train_dataset, cv_dataset, train_data_loader, cv_data_loader = \
init_dataset_and_dataloader(args, configs)
# Do some sanity checks and save config to args.model_dir
configs = check_modify_and_save_config(args, configs)
# Tensorboard summary
writer = init_summarywriter(args)
# load checkpoint
model = configs[args.model]
if args.checkpoint is not None:
model.load_state_dict(torch.load(args.checkpoint, map_location='cpu'))
else:
# Find and load the latest checkpoint
checkpoint_files = glob.glob(os.path.join(args.model_dir, '*.pt'))
if checkpoint_files:
latest_checkpoint = max(checkpoint_files, key=os.path.getctime)
logging.info(f"Loaded latest checkpoint from {latest_checkpoint}")
model.load_state_dict(torch.load(latest_checkpoint, map_location='cpu'))
if args.lora:
logging.info("Applying LoRA to the model...")
if not args.lora_target_modules:
raise ValueError("No target modules specified for LoRA. Please provide --lora_target_modules.")
lora_config = LoraConfig(
task_type="CAUSAL_LM", # Change to appropriate task type
inference_mode=False,
r=args.lora_rank,
lora_alpha=args.lora_alpha,
lora_dropout=args.lora_dropout,
target_modules=args.lora_target_modules
)
model.llm.model = get_peft_model(model.llm.model, lora_config)
# Optionally freeze the base model
else:
logging.info("LoRA is not enabled. Training the full model.")
# Dispatch model from cpu to gpu
model = wrap_cuda_model(args, model)
# Get optimizer & scheduler
model, optimizer, scheduler = init_optimizer_and_scheduler(args, configs, model)
# Initialize AMP for torch_ddp if fp16 is enabled
scaler = None
if args.fp16:
scaler = GradScaler()
logging.info("Initialized AMP GradScaler for mixed precision training.")
# Save init checkpoints
info_dict = deepcopy(configs['train_conf'])
# Get executor
executor = Executor()
# Start training loop
for epoch in range(info_dict['max_epoch']):
executor.epoch = epoch
train_dataset.set_epoch(epoch)
dist.barrier()
group_join = dist.new_group(backend="gloo", timeout=datetime.timedelta(seconds=args.timeout))
executor.train_one_epoch(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, group_join, scaler=scaler)
dist.destroy_process_group(group_join)
if __name__ == '__main__':
main()
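When --lora is enabled, the trained adapter weights live inside the peft wrapper around model.llm.model. A hedged sketch of merging them back into the base weights before export; merge_and_unload() is a standard peft PeftModel method, while the helper name and the surrounding save logic are illustrative only:

```python
# Hedged sketch: merge trained LoRA adapters back into the base model so the
# checkpoint can be loaded without peft at inference time. merge_and_unload() is a
# standard peft method; the helper name and save path handling are assumptions.
import torch
from peft import PeftModel


def merge_lora_and_save(model, out_path: str) -> None:
    if isinstance(model.llm.model, PeftModel):
        model.llm.model = model.llm.model.merge_and_unload()
    torch.save(model.state_dict(), out_path)
```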