f"Set weight decay: {i['weight_decay']} for {len(i['params'])} parameters"
)
        lr_scheduler = self.lr_scheduler_builder(optimizer)

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "interval": "step",
            },
        }
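
    # Note: interval="step" tells Lightning to call `scheduler.step()` after
    # every optimizer step instead of once per epoch. A minimal sketch of a
    # compatible `lr_scheduler_builder` (hypothetical example; any callable
    # mapping an optimizer to a torch LRScheduler works), e.g. linear warmup:
    #
    #     from torch.optim.lr_scheduler import LambdaLR
    #
    #     def lr_scheduler_builder(optimizer, warmup_steps=1000):
    #         return LambdaLR(
    #             optimizer,
    #             lr_lambda=lambda step: min(1.0, (step + 1) / warmup_steps),
    #         )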

    # Copied from https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py#L90
    def get_batch_logps(
        self,
        logits: torch.FloatTensor,
        labels: torch.LongTensor,
        average_log_prob: bool = False,
    ) -> torch.FloatTensor:
"""Compute the log probabilities of the given labels under the given logits.
Args:
logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, codebook_size, vocab_size)
labels: Labels for which to compute the log probabilities. Label tokens with a value of -100 are ignored. Shape: (batch_size, sequence_length, codebook_size)
average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
Returns:
A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
"""
        assert logits.shape[:-1] == labels.shape

        labels = labels.clone()
        loss_mask = labels != -100

        # dummy token; we'll ignore the losses on these tokens later
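        # Following the reference DPO implementation linked above: replace
        # masked positions with a dummy token, gather per-token log-probs from
        # the normalized logits, then reduce over the non-masked tokens (the
        # trailing dimension reduced here is the codebook dimension).
        labels[labels == -100] = 0

        per_token_logps = torch.gather(
            logits.log_softmax(-1), dim=-1, index=labels.unsqueeze(-1)
        ).squeeze(-1)

        if average_log_prob:
            return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
        else:
            return (per_token_logps * loss_mask).sum(-1)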
f"Generated audio of shape {fake_audios.shape}, equivalent to {audio_time:.2f} seconds from {indices.shape[1]} features, features/second: {indices.shape[1]/audio_time:.2f}"