Unverified Commit 450a181d authored by Susnato Dhar, committed by GitHub

Add Pop2Piano (#21785)



* init commit

* config updated also some modeling

* Processor and Model config combined

* extraction pipeline (up to before spectrogram & mel_conditioner) added but not properly tested

* model loading successful!

* feature extractor done!

* FE can now be called from HF

* postprocessing added in fe file

* same as prev commit

* Pop2PianoConfig doc done

* cfg docs slightly changed

* fe docs done

* batched

* batched working!

* temp

* v1

* checking

* trying to go with generate

* with generate and model tests passed

* before rebasing

* .

* tests done docs done remaining others & nits

* nits

* LogMelSpectogram shifted to FeatureExtractor

* is_tf removed from pop2piano/init

* import solved

* tokenization tests added

* minor fixes regarding modeling_pop2piano

* tokenizer changed to only return midi_object and other changes

* Updated paper abstract(Camera-ready version) (#2)

* more comments and nits

* ruff changes

* code quality fix

* sg comments

* t5 change added and rebased

* comments except batching

* batching done

* comments

* small doc fix

* example removed from modeling

* ckpt

* forward is compatible with fe and generation done

* comments

* comments

* code-quality fix(maybe)

* ckpts changed

* doc file changed from mdx to md

* test fixes

* tokenizer test fix

* changes

* nits done main changes remaining

* code modified

* Pop2PianoProcessor added with tests

* other comments

* added Pop2PianoProcessor to dummy_objects

* added require_onnx to modeling file

* changes

* update .md file

* remove extra line in index.md

* back to the main index

* added pop2piano to index

* Added tokenizer.__call__ with valid args and batch_decode and aligned the processor part too

* changes

* added return types to 2 tokenizer methods

* the PR build test might work now

* added backends

* PR build fix

* vocab added

* comments

* refactored vocab into 1 file

* added conversion script

* comments

* essentia version changed in .md

* comments

* more tokenizer tests added

* minor fix

* tests extended for outputs acc check

* small fix

---------
Co-authored-by: Jongho Choi <sweetcocoa@snu.ac.kr>
parent 6f041fcb
# coding=utf-8
# Copyright 2023 The Pop2Piano Authors and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Pop2Piano model."""
import copy
import math
from typing import Optional, Tuple, Union
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.utils.checkpoint import checkpoint
from transformers.generation import GenerationConfig
from ...activations import ACT2FN
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
Seq2SeqLMOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import ALL_LAYERNORM_LAYERS, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_torch_fx_proxy,
logging,
replace_return_docstrings,
)
from .configuration_pop2piano import Pop2PianoConfig
logger = logging.get_logger(__name__)
_load_pop2piano_layer_norm = True
try:
from apex.normalization import FusedRMSNorm
_load_pop2piano_layer_norm = False
logger.info("Discovered apex.normalization.FusedRMSNorm - will use it instead of Pop2PianoLayerNorm")
except ImportError:
# using the normal Pop2PianoLayerNorm
pass
except Exception:
logger.warning("Discovered apex but it failed to load, falling back to Pop2PianoLayerNorm")
pass
_CONFIG_FOR_DOC = "Pop2PianoConfig"
_CHECKPOINT_FOR_DOC = "sweetcocoa/pop2piano"
POP2PIANO_PRETRAINED_MODEL_ARCHIVE_LIST = [
"sweetcocoa/pop2piano",
# See all Pop2Piano models at https://huggingface.co/models?filter=pop2piano
]
POP2PIANO_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Pop2Piano is a model with relative position embeddings
so you should be able to pad the inputs on both the right and the left. Indices can be obtained using
[`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids) To know more on how to prepare `input_ids` for pretraining
take a look at [Pop2Piano Training](./Pop2Piano#training).
attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
Indices of decoder input sequence tokens in the vocabulary. Indices can be obtained using
[`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
[What are decoder input IDs?](../glossary#decoder-input-ids) Pop2Piano uses the `pad_token_id` as the
starting token for `decoder_input_ids` generation. If `past_key_values` is used, optionally only the last
`decoder_input_ids` have to be input (see `past_key_values`).
decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
be used by default.
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
Mask to nullify selected heads of the self-attention modules in the encoder. Mask values selected in `[0,
1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0,
1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in
`[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
Tuple consists of (`last_hidden_state`, `optional`: *hidden_states*, `optional`: *attentions*)
`last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of hidden states at
the output of the last layer of the encoder. Used in the cross-attention of the decoder.
past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Does the same task as `inputs_embeds`. If `inputs_embeds` is not present but `input_features` is present
then `input_features` will be considered as `inputs_embeds`.
decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be
input (see `past_key_values`). This is useful if you want more control over how to convert
`decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix. If
`decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value of
`inputs_embeds`.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
# Copied from transformers.models.t5.modeling_t5.T5LayerNorm with T5->Pop2Piano
class Pop2PianoLayerNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
Construct a layernorm module in the Pop2Piano style. No bias and no subtraction of mean.
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
# Pop2Piano uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
# Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated
# w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
# half-precision inputs is done in fp32
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
# convert into half-precision if necessary
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)
return self.weight * hidden_states
if not _load_pop2piano_layer_norm:
Pop2PianoLayerNorm = FusedRMSNorm # noqa
ALL_LAYERNORM_LAYERS.append(Pop2PianoLayerNorm)
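# Illustrative sketch, not part of the module: Pop2PianoLayerNorm (like T5's) is an RMS norm, so up
# to the fp32 accumulation handled above it is numerically equivalent to:
#
#     x = torch.randn(2, 4, 8)
#     weight, eps = torch.ones(8), 1e-6
#     y = weight * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)  # no mean subtraction, no bias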
# Copied from transformers.models.t5.modeling_t5.T5DenseActDense with T5->Pop2Piano,t5->pop2piano
class Pop2PianoDenseActDense(nn.Module):
def __init__(self, config: Pop2PianoConfig):
super().__init__()
self.wi = nn.Linear(config.d_model, config.d_ff, bias=False)
self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
self.dropout = nn.Dropout(config.dropout_rate)
self.act = ACT2FN[config.dense_act_fn]
def forward(self, hidden_states):
hidden_states = self.wi(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.dropout(hidden_states)
if (
isinstance(self.wo.weight, torch.Tensor)
and hidden_states.dtype != self.wo.weight.dtype
and self.wo.weight.dtype != torch.int8
):
hidden_states = hidden_states.to(self.wo.weight.dtype)
hidden_states = self.wo(hidden_states)
return hidden_states
# Copied from transformers.models.t5.modeling_t5.T5DenseGatedActDense with T5->Pop2Piano
class Pop2PianoDenseGatedActDense(nn.Module):
def __init__(self, config: Pop2PianoConfig):
super().__init__()
self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
self.dropout = nn.Dropout(config.dropout_rate)
self.act = ACT2FN[config.dense_act_fn]
def forward(self, hidden_states):
hidden_gelu = self.act(self.wi_0(hidden_states))
hidden_linear = self.wi_1(hidden_states)
hidden_states = hidden_gelu * hidden_linear
hidden_states = self.dropout(hidden_states)
# To make 8bit quantization work for google/flan-t5-xxl, self.wo is kept in float32.
# See https://github.com/huggingface/transformers/issues/20287
# we also make sure the weights are not in `int8` in case users will force `_keep_in_fp32_modules` to be `None`
if (
isinstance(self.wo.weight, torch.Tensor)
and hidden_states.dtype != self.wo.weight.dtype
and self.wo.weight.dtype != torch.int8
):
hidden_states = hidden_states.to(self.wo.weight.dtype)
hidden_states = self.wo(hidden_states)
return hidden_states
# Copied from transformers.models.t5.modeling_t5.T5LayerFF with T5->Pop2Piano
class Pop2PianoLayerFF(nn.Module):
def __init__(self, config: Pop2PianoConfig):
super().__init__()
if config.is_gated_act:
self.DenseReluDense = Pop2PianoDenseGatedActDense(config)
else:
self.DenseReluDense = Pop2PianoDenseActDense(config)
self.layer_norm = Pop2PianoLayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dropout = nn.Dropout(config.dropout_rate)
def forward(self, hidden_states):
forwarded_states = self.layer_norm(hidden_states)
forwarded_states = self.DenseReluDense(forwarded_states)
hidden_states = hidden_states + self.dropout(forwarded_states)
return hidden_states
# Copied from transformers.models.t5.modeling_t5.T5Attention with T5->Pop2Piano,t5->pop2piano
class Pop2PianoAttention(nn.Module):
def __init__(self, config: Pop2PianoConfig, has_relative_attention_bias=False):
super().__init__()
self.is_decoder = config.is_decoder
self.has_relative_attention_bias = has_relative_attention_bias
self.relative_attention_num_buckets = config.relative_attention_num_buckets
self.relative_attention_max_distance = config.relative_attention_max_distance
self.d_model = config.d_model
self.key_value_proj_dim = config.d_kv
self.n_heads = config.num_heads
self.dropout = config.dropout_rate
self.inner_dim = self.n_heads * self.key_value_proj_dim
# Mesh TensorFlow initialization to avoid scaling before softmax
self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)
if self.has_relative_attention_bias:
self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
self.pruned_heads = set()
self.gradient_checkpointing = False
def prune_heads(self, heads):
if len(heads) == 0:
return
heads, index = find_pruneable_heads_and_indices(
heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads
)
# Prune linear layers
self.q = prune_linear_layer(self.q, index)
self.k = prune_linear_layer(self.k, index)
self.v = prune_linear_layer(self.v, index)
self.o = prune_linear_layer(self.o, index, dim=1)
# Update hyper params
self.n_heads = self.n_heads - len(heads)
self.inner_dim = self.key_value_proj_dim * self.n_heads
self.pruned_heads = self.pruned_heads.union(heads)
@staticmethod
def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
"""
Adapted from Mesh Tensorflow:
https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
Translate relative position to a bucket number for relative attention. The relative position is defined as
memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
This should allow for more graceful generalization to longer sequences than the model has been trained on
Args:
relative_position: an int32 Tensor
bidirectional: a boolean - whether the attention is bidirectional
num_buckets: an integer
max_distance: an integer
Returns:
a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
"""
relative_buckets = 0
if bidirectional:
num_buckets //= 2
relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
relative_position = torch.abs(relative_position)
else:
relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
# now relative_position is in the range [0, inf)
# half of the buckets are for exact increments in positions
max_exact = num_buckets // 2
is_small = relative_position < max_exact
# The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
relative_position_if_large = max_exact + (
torch.log(relative_position.float() / max_exact)
/ math.log(max_distance / max_exact)
* (num_buckets - max_exact)
).to(torch.long)
relative_position_if_large = torch.min(
relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
)
relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
return relative_buckets
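# Illustrative sketch, not part of the module: with the defaults (bidirectional=True, num_buckets=32,
# max_distance=128) each direction gets 16 buckets; offsets smaller than 8 each get their own bucket
# (shifted by 16 for positive distances), while larger offsets share logarithmically sized buckets up
# to max_distance. For example:
#
#     rel = torch.tensor([0, 3, 20, 500])
#     Pop2PianoAttention._relative_position_bucket(-rel)  # negative/self offsets -> small bucket ids
#     Pop2PianoAttention._relative_position_bucket(rel)   # positive offsets -> ids shifted by 16
#
# All offsets beyond max_distance collapse into the last bucket of their half.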
def compute_bias(self, query_length, key_length, device=None):
"""Compute binned relative position bias"""
if device is None:
device = self.relative_attention_bias.weight.device
context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
relative_position = memory_position - context_position # shape (query_length, key_length)
relative_position_bucket = self._relative_position_bucket(
relative_position, # shape (query_length, key_length)
bidirectional=(not self.is_decoder),
num_buckets=self.relative_attention_num_buckets,
max_distance=self.relative_attention_max_distance,
)
values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads)
values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
return values
def forward(
self,
hidden_states,
mask=None,
key_value_states=None,
position_bias=None,
past_key_value=None,
layer_head_mask=None,
query_length=None,
use_cache=False,
output_attentions=False,
):
"""
Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
"""
# Input is (batch_size, seq_length, dim)
# Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
# past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
batch_size, seq_length = hidden_states.shape[:2]
real_seq_length = seq_length
if past_key_value is not None:
if len(past_key_value) != 2:
raise ValueError(
f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states"
)
real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length
key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]
def shape(states):
"""projection"""
return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
def unshape(states):
"""reshape"""
return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
def project(hidden_states, proj_layer, key_value_states, past_key_value):
"""projects hidden states correctly to key/query states"""
if key_value_states is None:
# self-attn
# (batch_size, n_heads, seq_length, dim_per_head)
hidden_states = shape(proj_layer(hidden_states))
elif past_key_value is None:
# cross-attn
# (batch_size, n_heads, seq_length, dim_per_head)
hidden_states = shape(proj_layer(key_value_states))
if past_key_value is not None:
if key_value_states is None:
# self-attn
# (batch_size, n_heads, key_length, dim_per_head)
hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
elif past_key_value.shape[2] != key_value_states.shape[1]:
# checking that the `sequence_length` of the `past_key_value` is the same as
# the provided `key_value_states` to support prefix tuning
# cross-attn
# (batch_size, n_heads, seq_length, dim_per_head)
hidden_states = shape(proj_layer(key_value_states))
else:
# cross-attn
hidden_states = past_key_value
return hidden_states
# get query states
query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head)
# get key/value states
key_states = project(
hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None
)
value_states = project(
hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
)
# compute scores
scores = torch.matmul(
query_states, key_states.transpose(3, 2)
) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
if position_bias is None:
if not self.has_relative_attention_bias:
position_bias = torch.zeros(
(1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
)
if self.gradient_checkpointing and self.training:
position_bias.requires_grad = True
else:
position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device)
# if key and values are already calculated
# we want only the last query position bias
if past_key_value is not None:
position_bias = position_bias[:, :, -hidden_states.size(1) :, :]
if mask is not None:
position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length)
if self.pruned_heads:
mask = torch.ones(position_bias.shape[1])
mask[list(self.pruned_heads)] = 0
position_bias_masked = position_bias[:, mask.bool()]
else:
position_bias_masked = position_bias
scores += position_bias_masked
attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
scores
) # (batch_size, n_heads, seq_length, key_length)
attn_weights = nn.functional.dropout(
attn_weights, p=self.dropout, training=self.training
) # (batch_size, n_heads, seq_length, key_length)
# Mask heads if we want to
if layer_head_mask is not None:
attn_weights = attn_weights * layer_head_mask
attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim)
attn_output = self.o(attn_output)
present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None
outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)
if output_attentions:
outputs = outputs + (attn_weights,)
return outputs
# Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5->Pop2Piano,t5->pop2piano
class Pop2PianoLayerSelfAttention(nn.Module):
def __init__(self, config, has_relative_attention_bias=False):
super().__init__()
self.SelfAttention = Pop2PianoAttention(config, has_relative_attention_bias=has_relative_attention_bias)
self.layer_norm = Pop2PianoLayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dropout = nn.Dropout(config.dropout_rate)
def forward(
self,
hidden_states,
attention_mask=None,
position_bias=None,
layer_head_mask=None,
past_key_value=None,
use_cache=False,
output_attentions=False,
):
normed_hidden_states = self.layer_norm(hidden_states)
attention_output = self.SelfAttention(
normed_hidden_states,
mask=attention_mask,
position_bias=position_bias,
layer_head_mask=layer_head_mask,
past_key_value=past_key_value,
use_cache=use_cache,
output_attentions=output_attentions,
)
hidden_states = hidden_states + self.dropout(attention_output[0])
outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them
return outputs
# Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5->Pop2Piano,t5->pop2piano
class Pop2PianoLayerCrossAttention(nn.Module):
def __init__(self, config):
super().__init__()
self.EncDecAttention = Pop2PianoAttention(config, has_relative_attention_bias=False)
self.layer_norm = Pop2PianoLayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dropout = nn.Dropout(config.dropout_rate)
def forward(
self,
hidden_states,
key_value_states,
attention_mask=None,
position_bias=None,
layer_head_mask=None,
past_key_value=None,
use_cache=False,
query_length=None,
output_attentions=False,
):
normed_hidden_states = self.layer_norm(hidden_states)
attention_output = self.EncDecAttention(
normed_hidden_states,
mask=attention_mask,
key_value_states=key_value_states,
position_bias=position_bias,
layer_head_mask=layer_head_mask,
past_key_value=past_key_value,
use_cache=use_cache,
query_length=query_length,
output_attentions=output_attentions,
)
layer_output = hidden_states + self.dropout(attention_output[0])
outputs = (layer_output,) + attention_output[1:] # add attentions if we output them
return outputs
# Copied from transformers.models.t5.modeling_t5.T5Block with T5->Pop2Piano,t5->pop2piano
class Pop2PianoBlock(nn.Module):
def __init__(self, config, has_relative_attention_bias=False):
super().__init__()
self.is_decoder = config.is_decoder
self.layer = nn.ModuleList()
self.layer.append(Pop2PianoLayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias))
if self.is_decoder:
self.layer.append(Pop2PianoLayerCrossAttention(config))
self.layer.append(Pop2PianoLayerFF(config))
def forward(
self,
hidden_states,
attention_mask=None,
position_bias=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
encoder_decoder_position_bias=None,
layer_head_mask=None,
cross_attn_layer_head_mask=None,
past_key_value=None,
use_cache=False,
output_attentions=False,
return_dict=True,
):
if past_key_value is not None:
if not self.is_decoder:
logger.warning("`past_key_values` is passed to the encoder. Please make sure this is intended.")
expected_num_past_key_values = 2 if encoder_hidden_states is None else 4
if len(past_key_value) != expected_num_past_key_values:
raise ValueError(
f"There should be {expected_num_past_key_values} past states. "
f"{'2 (past / key) for cross attention. ' if expected_num_past_key_values == 4 else ''}"
f"Got {len(past_key_value)} past key / value states"
)
self_attn_past_key_value = past_key_value[:2]
cross_attn_past_key_value = past_key_value[2:]
else:
self_attn_past_key_value, cross_attn_past_key_value = None, None
self_attention_outputs = self.layer[0](
hidden_states,
attention_mask=attention_mask,
position_bias=position_bias,
layer_head_mask=layer_head_mask,
past_key_value=self_attn_past_key_value,
use_cache=use_cache,
output_attentions=output_attentions,
)
hidden_states, present_key_value_state = self_attention_outputs[:2]
attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights
# clamp inf values to enable fp16 training
if hidden_states.dtype == torch.float16:
clamp_value = torch.where(
torch.isinf(hidden_states).any(),
torch.finfo(hidden_states.dtype).max - 1000,
torch.finfo(hidden_states.dtype).max,
)
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
do_cross_attention = self.is_decoder and encoder_hidden_states is not None
if do_cross_attention:
# the actual query length is unknown for cross attention
# if using past key value states. Need to inject it here
if present_key_value_state is not None:
query_length = present_key_value_state[0].shape[2]
else:
query_length = None
cross_attention_outputs = self.layer[1](
hidden_states,
key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
position_bias=encoder_decoder_position_bias,
layer_head_mask=cross_attn_layer_head_mask,
past_key_value=cross_attn_past_key_value,
query_length=query_length,
use_cache=use_cache,
output_attentions=output_attentions,
)
hidden_states = cross_attention_outputs[0]
# clamp inf values to enable fp16 training
if hidden_states.dtype == torch.float16:
clamp_value = torch.where(
torch.isinf(hidden_states).any(),
torch.finfo(hidden_states.dtype).max - 1000,
torch.finfo(hidden_states.dtype).max,
)
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
# Combine self attn and cross attn key value states
if present_key_value_state is not None:
present_key_value_state = present_key_value_state + cross_attention_outputs[1]
# Keep cross-attention outputs and relative position weights
attention_outputs = attention_outputs + cross_attention_outputs[2:]
# Apply Feed Forward layer
hidden_states = self.layer[-1](hidden_states)
# clamp inf values to enable fp16 training
if hidden_states.dtype == torch.float16:
clamp_value = torch.where(
torch.isinf(hidden_states).any(),
torch.finfo(hidden_states.dtype).max - 1000,
torch.finfo(hidden_states.dtype).max,
)
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
outputs = (hidden_states,)
if use_cache:
outputs = outputs + (present_key_value_state,) + attention_outputs
else:
outputs = outputs + attention_outputs
return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
class Pop2PianoPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = Pop2PianoConfig
base_model_prefix = "transformer"
is_parallelizable = False
supports_gradient_checkpointing = True
_no_split_modules = ["Pop2PianoBlock"]
_keep_in_fp32_modules = ["wo"]
def _init_weights(self, module):
"""Initialize the weights"""
factor = self.config.initializer_factor # Used for testing weights initialization
if isinstance(module, Pop2PianoLayerNorm):
module.weight.data.fill_(factor * 1.0)
elif isinstance(module, Pop2PianoConcatEmbeddingToMel):
module.embedding.weight.data.normal_(mean=0.0, std=factor * 1.0)
elif isinstance(module, Pop2PianoForConditionalGeneration):
# Mesh TensorFlow embeddings initialization
# See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624
module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
if hasattr(module, "lm_head") and not self.config.tie_word_embeddings:
module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0)
elif isinstance(module, Pop2PianoDenseActDense):
# Mesh TensorFlow FF initialization
# See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56
# and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89
module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
if hasattr(module.wi, "bias") and module.wi.bias is not None:
module.wi.bias.data.zero_()
module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
if hasattr(module.wo, "bias") and module.wo.bias is not None:
module.wo.bias.data.zero_()
elif isinstance(module, Pop2PianoDenseGatedActDense):
module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None:
module.wi_0.bias.data.zero_()
module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None:
module.wi_1.bias.data.zero_()
module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
if hasattr(module.wo, "bias") and module.wo.bias is not None:
module.wo.bias.data.zero_()
elif isinstance(module, Pop2PianoAttention):
# Mesh TensorFlow attention initialization to avoid scaling before softmax
# See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136
d_model = self.config.d_model
key_value_proj_dim = self.config.d_kv
n_heads = self.config.num_heads
module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5))
module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5))
if module.has_relative_attention_bias:
module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (Pop2PianoAttention, Pop2PianoStack)):
module.gradient_checkpointing = value
def _shift_right(self, input_ids):
decoder_start_token_id = self.config.decoder_start_token_id
pad_token_id = self.config.pad_token_id
if decoder_start_token_id is None:
raise ValueError(
"self.model.config.decoder_start_token_id has to be defined. In Pop2Piano it is usually set to the pad_token_id."
)
# shift inputs to the right
if is_torch_fx_proxy(input_ids):
# Item assignment is not supported natively for proxies.
shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id)
shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
else:
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
shifted_input_ids[..., 0] = decoder_start_token_id
if pad_token_id is None:
raise ValueError("self.model.config.pad_token_id has to be defined.")
# replace possible -100 values in labels by `pad_token_id`
shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
return shifted_input_ids
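# Illustrative sketch, not part of the module: how labels are turned into decoder inputs, assuming
# decoder_start_token_id == pad_token_id == 0 as the error message above suggests is the usual setup:
#
#     labels = torch.tensor([[5, 6, -100, 7]])
#     model._shift_right(labels)  # -> tensor([[0, 5, 6, 0]])
#
# The start token is prepended, everything shifts right by one (dropping the last position), and any
# remaining -100 placeholder is replaced with the pad token id.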
class Pop2PianoStack(Pop2PianoPreTrainedModel):
# Copied from transformers.models.t5.modeling_t5.T5Stack.__init__ with T5->Pop2Piano,t5->pop2piano
def __init__(self, config, embed_tokens=None):
super().__init__(config)
self.embed_tokens = embed_tokens
self.is_decoder = config.is_decoder
self.block = nn.ModuleList(
[Pop2PianoBlock(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)]
)
self.final_layer_norm = Pop2PianoLayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dropout = nn.Dropout(config.dropout_rate)
# Initialize weights and apply final processing
self.post_init()
# Model parallel
self.model_parallel = False
self.device_map = None
self.gradient_checkpointing = False
# Copied from transformers.models.t5.modeling_t5.T5Stack.get_input_embeddings
def get_input_embeddings(self):
return self.embed_tokens
# Copied from transformers.models.t5.modeling_t5.T5Stack.set_input_embeddings
def set_input_embeddings(self, new_embeddings):
self.embed_tokens = new_embeddings
def forward(
self,
input_ids=None,
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
inputs_embeds=None,
head_mask=None,
cross_attn_head_mask=None,
past_key_values=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
use_cache = use_cache if use_cache is not None else self.config.use_cache
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if input_ids is not None and inputs_embeds is not None:
err_msg_prefix = "decoder_" if self.is_decoder else ""
raise ValueError(
f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
)
elif input_ids is not None:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
err_msg_prefix = "decoder_" if self.is_decoder else ""
raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")
if inputs_embeds is None:
if self.embed_tokens is None:
raise ValueError("You have to initialize the model with valid token embeddings")
inputs_embeds = self.embed_tokens(input_ids)
batch_size, seq_length = input_shape
# required mask seq length can be calculated via length of past
mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length
if use_cache is True:
if not self.is_decoder:
raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder")
if attention_mask is None:
attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
encoder_seq_length = encoder_hidden_states.shape[1]
encoder_attention_mask = torch.ones(
batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long
)
# initialize past_key_values with `None` if past does not exist
if past_key_values is None:
past_key_values = [None] * len(self.block)
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
if self.is_decoder and encoder_hidden_states is not None:
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if encoder_attention_mask is None:
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device)
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
encoder_extended_attention_mask = None
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# Prepare head mask if needed
head_mask = self.get_head_mask(head_mask, self.config.num_layers)
cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers)
present_key_value_states = () if use_cache else None
all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
all_cross_attentions = () if (output_attentions and self.is_decoder) else None
position_bias = None
encoder_decoder_position_bias = None
hidden_states = self.dropout(inputs_embeds)
for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):
layer_head_mask = head_mask[i]
cross_attn_layer_head_mask = cross_attn_head_mask[i]
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return tuple(module(*inputs, use_cache, output_attentions))
return custom_forward
layer_outputs = checkpoint(
create_custom_forward(layer_module),
hidden_states,
extended_attention_mask,
position_bias,
encoder_hidden_states,
encoder_extended_attention_mask,
encoder_decoder_position_bias,
layer_head_mask,
cross_attn_layer_head_mask,
None, # past_key_value is always None with gradient checkpointing
)
else:
layer_outputs = layer_module(
hidden_states,
attention_mask=extended_attention_mask,
position_bias=position_bias,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask,
encoder_decoder_position_bias=encoder_decoder_position_bias,
layer_head_mask=layer_head_mask,
cross_attn_layer_head_mask=cross_attn_layer_head_mask,
past_key_value=past_key_value,
use_cache=use_cache,
output_attentions=output_attentions,
)
# layer_outputs is a tuple with:
# hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
if use_cache is False:
layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]
hidden_states, present_key_value_state = layer_outputs[:2]
# We share the position biases between the layers - the first layer stores them
# layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
# (cross-attention position bias), (cross-attention weights)
position_bias = layer_outputs[2]
if self.is_decoder and encoder_hidden_states is not None:
encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
# append next layer key value states
if use_cache:
present_key_value_states = present_key_value_states + (present_key_value_state,)
if output_attentions:
all_attentions = all_attentions + (layer_outputs[3],)
if self.is_decoder:
all_cross_attentions = all_cross_attentions + (layer_outputs[5],)
hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.dropout(hidden_states)
# Add last layer
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(
v
for v in [
hidden_states,
present_key_value_states,
all_hidden_states,
all_attentions,
all_cross_attentions,
]
if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=present_key_value_states,
hidden_states=all_hidden_states,
attentions=all_attentions,
cross_attentions=all_cross_attentions,
)
class Pop2PianoConcatEmbeddingToMel(nn.Module):
"""Embedding Matrix for `composer` tokens."""
def __init__(self, config):
super().__init__()
self.embedding = nn.Embedding(num_embeddings=config.composer_vocab_size, embedding_dim=config.d_model)
def forward(self, feature, index_value, embedding_offset):
index_shifted = index_value - embedding_offset
composer_embedding = self.embedding(index_shifted).unsqueeze(1)
inputs_embeds = torch.cat([composer_embedding, feature], dim=1)
return inputs_embeds
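# Illustrative sketch, not part of the module: the composer token is embedded into a single d_model
# vector and prepended along the time axis, so the shapes behave roughly like:
#
#     feature              # (batch_size, sequence_length, d_model) input features
#     composer_embedding   # (batch_size, 1, d_model) after the unsqueeze above
#     inputs_embeds        # (batch_size, sequence_length + 1, d_model) after torch.cat(..., dim=1)
#
# `embedding_offset` re-bases the composer ids so they index this small composer embedding table from 0.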
Pop2Piano_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`Pop2PianoConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
@add_start_docstrings("""Pop2Piano Model with a `language modeling` head on top.""", Pop2Piano_START_DOCSTRING)
class Pop2PianoForConditionalGeneration(Pop2PianoPreTrainedModel):
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
def __init__(self, config: Pop2PianoConfig):
super().__init__(config)
self.config = config
self.model_dim = config.d_model
self.shared = nn.Embedding(config.vocab_size, config.d_model)
self.mel_conditioner = Pop2PianoConcatEmbeddingToMel(config)
encoder_config = copy.deepcopy(config)
encoder_config.is_decoder = False
encoder_config.use_cache = False
encoder_config.is_encoder_decoder = False
self.encoder = Pop2PianoStack(encoder_config, self.shared)
decoder_config = copy.deepcopy(config)
decoder_config.is_decoder = True
decoder_config.is_encoder_decoder = False
decoder_config.num_layers = config.num_decoder_layers
self.decoder = Pop2PianoStack(decoder_config, self.shared)
self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.shared
def set_input_embeddings(self, new_embeddings):
self.shared = new_embeddings
self.encoder.set_input_embeddings(new_embeddings)
self.decoder.set_input_embeddings(new_embeddings)
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def get_output_embeddings(self):
return self.lm_head
def get_encoder(self):
return self.encoder
def get_decoder(self):
return self.decoder
def get_mel_conditioner_outputs(
self,
input_features: torch.FloatTensor,
composer: str,
generation_config: GenerationConfig,
attention_mask: torch.FloatTensor = None,
):
"""
This method is used to concatenate mel conditioner tokens at the front of the input_features in order to
control the type of MIDI token generated by the model.
Args:
input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
input features extracted from the feature extractor.
composer (`str`):
composer token which determines the type of MIDI tokens to be generated.
generation_config (`~generation.GenerationConfig`):
The generation config is used to get the composer to feature_token mapping.
attention_mask (`torch.FloatTensor`, *optional*):
For batched generation `input_features` are padded to have the same shape across all examples.
`attention_mask` helps to determine which areas were padded and which were not.
- 1 for tokens that are **not padded**,
- 0 for tokens that are **padded**.
"""
composer_to_feature_token = generation_config.composer_to_feature_token
if composer not in composer_to_feature_token.keys():
raise ValueError(
f"Please choose a composer from {list(composer_to_feature_token.keys())}. Composer received - {composer}"
)
composer_value = composer_to_feature_token[composer]
composer_value = torch.tensor(composer_value, device=self.device)
composer_value = composer_value.repeat(input_features.shape[0])
embedding_offset = min(composer_to_feature_token.values())
input_features = self.mel_conditioner(
feature=input_features,
index_value=composer_value,
embedding_offset=embedding_offset,
)
if attention_mask is not None:
input_features[~attention_mask[:, 0].bool()] = 0.0
# since self.mel_conditioner adds a new column at the front of inputs_embeds we need to do the same for attention_mask to keep the shapes the same
attention_mask = torch.concatenate([attention_mask[:, 0].view(-1, 1), attention_mask], axis=1)
return input_features, attention_mask
return input_features, None
@add_start_docstrings_to_model_forward(POP2PIANO_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.BoolTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
decoder_head_mask: Optional[torch.FloatTensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
input_features: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ...,
config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for
labels in `[0, ..., config.vocab_size]`
Returns:
"""
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if inputs_embeds is not None and input_features is not None:
raise ValueError("Both `inputs_embeds` and `input_features` received! Please provide only one of them")
elif input_features is not None and inputs_embeds is None:
inputs_embeds = input_features
# Encode if needed (training, first prediction pass)
if encoder_outputs is None:
# Convert encoder inputs in embeddings if needed
encoder_outputs = self.encoder(
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
encoder_outputs = BaseModelOutput(
last_hidden_state=encoder_outputs[0],
hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
)
hidden_states = encoder_outputs[0]
if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
# get decoder inputs from shifting lm labels to the right
decoder_input_ids = self._shift_right(labels)
# Decode
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
attention_mask=decoder_attention_mask,
inputs_embeds=decoder_inputs_embeds,
past_key_values=past_key_values,
encoder_hidden_states=hidden_states,
encoder_attention_mask=attention_mask,
head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = decoder_outputs[0]
if self.config.tie_word_embeddings:
# Rescale output before projecting on vocab
# See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
sequence_output = sequence_output * (self.model_dim**-0.5)
lm_logits = self.lm_head(sequence_output)
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss(ignore_index=-100)
loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
if not return_dict:
output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
return ((loss,) + output) if loss is not None else output
return Seq2SeqLMOutput(
loss=loss,
logits=lm_logits,
past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions,
)
@torch.no_grad()
def generate(
self,
input_features,
attention_mask=None,
composer="composer1",
generation_config=None,
**kwargs,
):
"""
Generates token ids for midi outputs.
<Tip warning={true}>
Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
model's default generation configuration. You can override any `generation_config` by passing the corresponding
parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`. For an overview of generation
strategies and code examples, check out the [following guide](./generation_strategies).
</Tip>
Parameters:
input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
This is the featurized version of audio generated by `Pop2PianoFeatureExtractor`.
attention_mask (`torch.FloatTensor`, *optional*):
For batched generation `input_features` are padded to have the same shape across all examples.
`attention_mask` helps to determine which areas were padded and which were not.
- 1 for tokens that are **not padded**,
- 0 for tokens that are **padded**.
composer (`str`, *optional*, defaults to `"composer1"`):
This value is passed to `Pop2PianoConcatEmbeddingToMel` to generate different embeddings for each
`"composer"`. Please make sure that the composet value is present in `composer_to_feature_token` in
`generation_config`. For an example please see
https://huggingface.co/sweetcocoa/pop2piano/blob/main/generation_config.json .
generation_config (`~generation.GenerationConfig`, *optional*):
The generation configuration to be used as base parametrization for the generation call. `**kwargs`
passed to generate matching the attributes of `generation_config` will override them. If
`generation_config` is not provided, the default will be used, which has the following loading
priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
default values, whose documentation should be checked to parameterize generation.
kwargs:
Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be
forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.
Return:
[`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
or when `config.return_dict_in_generate=True`) or a `torch.LongTensor`.
Since Pop2Piano is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
[`~utils.ModelOutput`] types are:
- [`~generation.GreedySearchEncoderDecoderOutput`],
- [`~generation.SampleEncoderDecoderOutput`],
- [`~generation.BeamSearchEncoderDecoderOutput`],
- [`~generation.BeamSampleEncoderDecoderOutput`]
"""
if generation_config is None:
generation_config = self.generation_config
generation_config.update(**kwargs)
# check for composer_to_feature_token
if not hasattr(generation_config, "composer_to_feature_token"):
raise ValueError(
"`composer_to_feature_token` was not found! Please refer to "
"https://huggingface.co/sweetcocoa/pop2piano/blob/main/generation_config.json"
"and parse a dict like that."
)
if len(generation_config.composer_to_feature_token) != self.config.composer_vocab_size:
raise ValueError(
"config.composer_vocab_size must be same as the number of keys in "
f"generation_config.composer_to_feature_token! "
f"Found {self.config.composer_vocab_size} vs {len(generation_config.composer_to_feature_token)}."
)
# to control the variation of generated MIDI tokens we concatenate mel-conditioner tokens(which depends on composer_token)
# at the front of input_features.
input_features, attention_mask = self.get_mel_conditioner_outputs(
input_features=input_features,
attention_mask=attention_mask,
composer=composer,
generation_config=generation_config,
)
return super().generate(
inputs=None,
inputs_embeds=input_features,
attention_mask=attention_mask,
generation_config=generation_config,
**kwargs,
)
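# Illustrative sketch, not part of the module: a typical end-to-end call, assuming the
# `sweetcocoa/pop2piano` checkpoint and a `Pop2PianoProcessor` (defined further below); the exact
# keyword names on the processor side are an assumption:
#
#     processor = Pop2PianoProcessor.from_pretrained("sweetcocoa/pop2piano")
#     model = Pop2PianoForConditionalGeneration.from_pretrained("sweetcocoa/pop2piano")
#     inputs = processor(audio=audio_array, sampling_rate=44100, return_tensors="pt")
#     token_ids = model.generate(input_features=inputs["input_features"], composer="composer1")
#     midi = processor.batch_decode(token_ids, feature_extractor_output=inputs)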
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
use_cache=None,
encoder_outputs=None,
**kwargs,
):
# cut decoder_input_ids if past is used
if past_key_values is not None:
input_ids = input_ids[:, -1:]
return {
"decoder_input_ids": input_ids,
"past_key_values": past_key_values,
"encoder_outputs": encoder_outputs,
"attention_mask": attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
"use_cache": use_cache,
}
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
return self._shift_right(labels)
def _reorder_cache(self, past_key_values, beam_idx):
# if decoder past is not included in output
# speedy decoding is disabled and no need to reorder
if past_key_values is None:
logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
return past_key_values
reordered_decoder_past = ()
for layer_past_states in past_key_values:
# get the correct batch idx from layer past batch dim
# batch dim of `past` is at 2nd position
reordered_layer_past_states = ()
for layer_past_state in layer_past_states:
# need to set correct `past` for each of the four key / value states
reordered_layer_past_states = reordered_layer_past_states + (
layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)),
)
if reordered_layer_past_states[0].shape != layer_past_states[0].shape:
raise ValueError(
f"reordered_layer_past_states[0] shape {reordered_layer_past_states[0].shape} and layer_past_states[0] shape {layer_past_states[0].shape} mismatched"
)
if len(reordered_layer_past_states) != len(layer_past_states):
raise ValueError(
f"length of reordered_layer_past_states {len(reordered_layer_past_states)} and length of layer_past_states {len(layer_past_states)} mismatched"
)
reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
return reordered_decoder_past
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Processor class for Pop2Piano."""
import os
from typing import List, Optional, Union
import numpy as np
from ...feature_extraction_utils import BatchFeature
from ...processing_utils import ProcessorMixin
from ...tokenization_utils import BatchEncoding, PaddingStrategy, TruncationStrategy
from ...utils import TensorType
class Pop2PianoProcessor(ProcessorMixin):
r"""
Constructs a Pop2Piano processor which wraps a Pop2Piano Feature Extractor and Pop2Piano Tokenizer into a single
processor.
[`Pop2PianoProcessor`] offers all the functionalities of [`Pop2PianoFeatureExtractor`] and [`Pop2PianoTokenizer`].
See the docstring of [`~Pop2PianoProcessor.__call__`] and [`~Pop2PianoProcessor.decode`] for more information.
Args:
feature_extractor (`Pop2PianoFeatureExtractor`):
An instance of [`Pop2PianoFeatureExtractor`]. The feature extractor is a required input.
tokenizer (`Pop2PianoTokenizer`):
An instance of [`Pop2PianoTokenizer`]. The tokenizer is a required input.
"""
attributes = ["feature_extractor", "tokenizer"]
feature_extractor_class = "Pop2PianoFeatureExtractor"
tokenizer_class = "Pop2PianoTokenizer"
def __init__(self, feature_extractor, tokenizer):
super().__init__(feature_extractor, tokenizer)
def __call__(
self,
audio: Union[np.ndarray, List[float], List[np.ndarray]] = None,
sampling_rate: Union[int, List[int]] = None,
steps_per_beat: int = 2,
resample: Optional[bool] = True,
notes: Union[List, TensorType] = None,
padding: Union[bool, str, PaddingStrategy] = False,
truncation: Union[bool, str, TruncationStrategy] = None,
max_length: Optional[int] = None,
pad_to_multiple_of: Optional[int] = None,
verbose: bool = True,
**kwargs,
) -> Union[BatchFeature, BatchEncoding]:
"""
This method uses [`Pop2PianoFeatureExtractor.__call__`] method to prepare log-mel-spectrograms for the model,
and [`Pop2PianoTokenizer.__call__`] to prepare token_ids from notes.
Please refer to the docstring of the above two methods for more information.
"""
# Since the feature extractor needs both audio and sampling_rate and the tokenizer needs notes, we must
# check for both.
if (audio is None and sampling_rate is None) and (notes is None):
raise ValueError(
"You have to specify at least audios and sampling_rate in order to use feature extractor or "
"notes to use the tokenizer part."
)
if audio is not None and sampling_rate is not None:
inputs = self.feature_extractor(
audio=audio,
sampling_rate=sampling_rate,
steps_per_beat=steps_per_beat,
resample=resample,
**kwargs,
)
if notes is not None:
encoded_token_ids = self.tokenizer(
notes=notes,
padding=padding,
truncation=truncation,
max_length=max_length,
pad_to_multiple_of=pad_to_multiple_of,
verbose=verbose,
**kwargs,
)
if notes is None:
return inputs
elif audio is None or sampling_rate is None:
return encoded_token_ids
else:
inputs["token_ids"] = encoded_token_ids["token_ids"]
return inputs
def batch_decode(
self,
token_ids,
feature_extractor_output: BatchFeature,
return_midi: bool = True,
) -> BatchEncoding:
"""
This method uses the [`Pop2PianoTokenizer.batch_decode`] method to convert model-generated token_ids to midi_notes.
Please refer to the docstring of the above method for more information.
"""
return self.tokenizer.batch_decode(
token_ids=token_ids, feature_extractor_output=feature_extractor_output, return_midi=return_midi
)
@property
def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names
feature_extractor_input_names = self.feature_extractor.model_input_names
return list(dict.fromkeys(tokenizer_input_names + feature_extractor_input_names))
def save_pretrained(self, save_directory, **kwargs):
if os.path.isfile(save_directory):
raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file")
os.makedirs(save_directory, exist_ok=True)
return super().save_pretrained(save_directory, **kwargs)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
args = cls._get_arguments_from_pretrained(pretrained_model_name_or_path, **kwargs)
return cls(*args)
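# Example usage (illustrative sketch): the two input paths handled by `Pop2PianoProcessor.__call__`.
# The checkpoint name is the `sweetcocoa/pop2piano` repository referenced elsewhere in this PR, and the
# waveform/notes below are placeholder values, so treat this as an assumption-level sketch only.
if __name__ == "__main__":
    import numpy as np

    processor = Pop2PianoProcessor.from_pretrained("sweetcocoa/pop2piano")

    # 1) audio + sampling_rate -> feature extractor path, returns log-mel `input_features` (plus beatsteps)
    audio = np.zeros([1_000_000], dtype=np.float32)  # silent placeholder waveform, as in the tests
    features = processor(audio=audio, sampling_rate=44_100, return_tensors="pt")

    # 2) notes -> tokenizer path, each row is [onset idx, offset idx, pitch, velocity]
    notes = np.array([[0, 4, 60, 77], [2, 6, 64, 77]])
    token_ids = processor(notes=notes)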
# coding=utf-8
# Copyright 2023 The Pop2Piano Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization class for Pop2Piano."""
import json
import os
from typing import List, Optional, Tuple, Union
import numpy as np
from ...feature_extraction_utils import BatchFeature
from ...tokenization_utils import AddedToken, BatchEncoding, PaddingStrategy, PreTrainedTokenizer, TruncationStrategy
from ...utils import TensorType, is_pretty_midi_available, logging, requires_backends, to_numpy
if is_pretty_midi_available():
import pretty_midi
logger = logging.get_logger(__name__)
## TODO : changing checkpoints from `susnato/pop2piano_dev` to `sweetcocoa/pop2piano` after the PR is approved
VOCAB_FILES_NAMES = {
"vocab": "vocab.json",
}
PRETRAINED_VOCAB_FILES_MAP = {
"vocab": {
"susnato/pop2piano_dev": "https://huggingface.co/susnato/pop2piano_dev/blob/main/vocab.json",
},
}
def token_time_to_note(number, cutoff_time_idx, current_idx):
current_idx += number
if cutoff_time_idx is not None:
current_idx = min(current_idx, cutoff_time_idx)
return current_idx
def token_note_to_note(number, current_velocity, default_velocity, note_onsets_ready, current_idx, notes):
if note_onsets_ready[number] is not None:
# offset with onset
onset_idx = note_onsets_ready[number]
if onset_idx < current_idx:
# Time shift after previous note_on
offset_idx = current_idx
notes.append([onset_idx, offset_idx, number, default_velocity])
onsets_ready = None if current_velocity == 0 else current_idx
note_onsets_ready[number] = onsets_ready
else:
note_onsets_ready[number] = current_idx
return notes
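# Worked example (illustrative): suppose a previous TOKEN_VELOCITY/TOKEN_NOTE pair recorded an onset for
# pitch 60 at time index 4, i.e. note_onsets_ready[60] == 4. A later call
#     token_note_to_note(number=60, current_velocity=0, default_velocity=77,
#                        note_onsets_ready=note_onsets_ready, current_idx=9, notes=[])
# closes that note by appending [4, 9, 60, 77] (onset idx, offset idx, pitch, velocity) and resets
# note_onsets_ready[60] to None because current_velocity is 0.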
class Pop2PianoTokenizer(PreTrainedTokenizer):
"""
Constructs a Pop2Piano tokenizer. This tokenizer does not require training.
This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
this superclass for more information regarding those methods.
Args:
vocab (`str`):
Path to the vocab file which contains the vocabulary.
default_velocity (`int`, *optional*, defaults to 77):
Determines the default velocity to be used while creating midi Notes.
num_bars (`int`, *optional*, defaults to 2):
Determines the cutoff_time_idx for each token.
"""
model_input_names = ["token_ids", "attention_mask"]
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
def __init__(
self,
vocab,
default_velocity=77,
num_bars=2,
unk_token="-1",
eos_token="1",
pad_token="0",
bos_token="2",
**kwargs,
):
unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
super().__init__(
unk_token=unk_token,
eos_token=eos_token,
pad_token=pad_token,
bos_token=bos_token,
**kwargs,
)
self.default_velocity = default_velocity
self.num_bars = num_bars
# Load the vocab
with open(vocab, "rb") as file:
self.encoder = json.load(file)
# create mappings for encoder
self.decoder = {v: k for k, v in self.encoder.items()}
@property
def vocab_size(self):
"""Returns the vocabulary size of the tokenizer."""
return len(self.encoder)
def get_vocab(self):
"""Returns the vocabulary of the tokenizer."""
return dict(self.encoder, **self.added_tokens_encoder)
def _convert_id_to_token(self, token_id: int) -> list:
"""
Decodes the token ids generated by the transformer into notes.
Args:
token_id (`int`):
This denotes the ids generated by the transformers to be converted to Midi tokens.
Returns:
`List`: A list consisting of token_type (`str`) and value (`int`).
"""
token_type_value = self.decoder.get(token_id, f"{self.unk_token}_TOKEN_TIME")
token_type_value = token_type_value.split("_")
token_type, value = "_".join(token_type_value[1:]), int(token_type_value[0])
return [token_type, value]
def _convert_token_to_id(self, token, token_type="TOKEN_TIME") -> int:
"""
Encodes the Midi tokens to transformer generated token ids.
Args:
token (`int`):
This denotes the token value.
token_type (`str`):
This denotes the type of the token. There are four types of midi tokens: "TOKEN_TIME",
"TOKEN_VELOCITY", "TOKEN_NOTE" and "TOKEN_SPECIAL".
Returns:
`int`: returns the id of the token.
"""
return self.encoder.get(f"{token}_{token_type}", int(self.unk_token))
def relative_batch_tokens_ids_to_notes(
self,
tokens: np.ndarray,
beat_offset_idx: int,
bars_per_batch: int,
cutoff_time_idx: int,
):
"""
Converts relative tokens to notes which are then used to generate a pretty midi object.
Args:
tokens (`numpy.ndarray`):
Tokens to be converted to notes.
beat_offset_idx (`int`):
Denotes beat offset index for each note in generated Midi.
bars_per_batch (`int`):
A parameter to control the Midi output generation.
cutoff_time_idx (`int`):
Denotes the cutoff time index for each note in generated Midi.
"""
notes = None
for index in range(len(tokens)):
_tokens = tokens[index]
_start_idx = beat_offset_idx + index * bars_per_batch * 4
_cutoff_time_idx = cutoff_time_idx + _start_idx
_notes = self.relative_tokens_ids_to_notes(
_tokens,
start_idx=_start_idx,
cutoff_time_idx=_cutoff_time_idx,
)
if len(_notes) == 0:
pass
elif notes is None:
notes = _notes
else:
notes = np.concatenate((notes, _notes), axis=0)
if notes is None:
return []
return notes
def relative_batch_tokens_ids_to_midi(
self,
tokens: np.ndarray,
beatstep: np.ndarray,
beat_offset_idx: int = 0,
bars_per_batch: int = 2,
cutoff_time_idx: int = 12,
):
"""
Converts tokens to Midi. This method calls `relative_batch_tokens_ids_to_notes` method to convert batch tokens
to notes then uses `notes_to_midi` method to convert them to Midi.
Args:
tokens (`numpy.ndarray`):
Denotes tokens which alongside beatstep will be converted to Midi.
beatstep (`np.ndarray`):
The beatstep obtained from the feature extractor, which is also used to generate the Midi.
beat_offset_idx (`int`, *optional*, defaults to 0):
Denotes beat offset index for each note in generated Midi.
bars_per_batch (`int`, *optional*, defaults to 2):
A parameter to control the Midi output generation.
cutoff_time_idx (`int`, *optional*, defaults to 12):
Denotes the cutoff time index for each note in generated Midi.
"""
beat_offset_idx = 0 if beat_offset_idx is None else beat_offset_idx
notes = self.relative_batch_tokens_ids_to_notes(
tokens=tokens,
beat_offset_idx=beat_offset_idx,
bars_per_batch=bars_per_batch,
cutoff_time_idx=cutoff_time_idx,
)
midi = self.notes_to_midi(notes, beatstep, offset_sec=beatstep[beat_offset_idx])
return midi
# Taken from the original code
# Please see https://github.com/sweetcocoa/pop2piano/blob/fac11e8dcfc73487513f4588e8d0c22a22f2fdc5/midi_tokenizer.py#L257
def relative_tokens_ids_to_notes(self, tokens: np.ndarray, start_idx: float, cutoff_time_idx: float = None):
"""
Converts relative tokens to notes which will then be used to create Pretty Midi objects.
Args:
tokens (`numpy.ndarray`):
Relative Tokens which will be converted to notes.
start_idx (`float`):
A parameter which denotes the starting index.
cutoff_time_idx (`float`, *optional*):
A parameter used while converting tokens to notes.
"""
words = [self._convert_id_to_token(token) for token in tokens]
current_idx = start_idx
current_velocity = 0
note_onsets_ready = [None for i in range(sum([k.endswith("NOTE") for k in self.encoder.keys()]) + 1)]
notes = []
for token_type, number in words:
if token_type == "TOKEN_SPECIAL":
if number == 1:
break
elif token_type == "TOKEN_TIME":
current_idx = token_time_to_note(
number=number, cutoff_time_idx=cutoff_time_idx, current_idx=current_idx
)
elif token_type == "TOKEN_VELOCITY":
current_velocity = number
elif token_type == "TOKEN_NOTE":
notes = token_note_to_note(
number=number,
current_velocity=current_velocity,
default_velocity=self.default_velocity,
note_onsets_ready=note_onsets_ready,
current_idx=current_idx,
notes=notes,
)
else:
raise ValueError("Token type not understood!")
for pitch, note_onset in enumerate(note_onsets_ready):
# force offset if no offset for each pitch
if note_onset is not None:
if cutoff_time_idx is None:
cutoff = note_onset + 1
else:
cutoff = max(cutoff_time_idx, note_onset + 1)
offset_idx = max(current_idx, cutoff)
notes.append([note_onset, offset_idx, pitch, self.default_velocity])
if len(notes) == 0:
return []
else:
notes = np.array(notes)
note_order = notes[:, 0] * 128 + notes[:, 1]
notes = notes[note_order.argsort()]
return notes
def notes_to_midi(self, notes: np.ndarray, beatstep: np.ndarray, offset_sec: float = 0.0):
"""
Converts notes to Midi.
Args:
notes (`numpy.ndarray`):
This is used to create Pretty Midi objects.
beatstep (`numpy.ndarray`):
This is the extrapolated beatstep that we get from feature extractor.
offset_sec (`float`, *optional*, defaults to 0.0):
This represents the offset seconds which is used while creating each Pretty Midi Note.
"""
requires_backends(self, ["pretty_midi"])
new_pm = pretty_midi.PrettyMIDI(resolution=384, initial_tempo=120.0)
new_inst = pretty_midi.Instrument(program=0)
new_notes = []
for onset_idx, offset_idx, pitch, velocity in notes:
new_note = pretty_midi.Note(
velocity=velocity,
pitch=pitch,
start=beatstep[onset_idx] - offset_sec,
end=beatstep[offset_idx] - offset_sec,
)
new_notes.append(new_note)
new_inst.notes = new_notes
new_pm.instruments.append(new_inst)
new_pm.remove_invalid_notes()
return new_pm
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
"""
Saves the tokenizer's vocabulary dictionary to the provided save_directory.
Args:
save_directory (`str`):
A path to the directory where the vocabulary will be saved.
filename_prefix (`Optional[str]`, *optional*):
A prefix to add to the names of the files saved by the tokenizer.
"""
if not os.path.isdir(save_directory):
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
return
# Save the encoder.
out_vocab_file = os.path.join(
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab"]
)
with open(out_vocab_file, "w") as file:
file.write(json.dumps(self.encoder))
return (out_vocab_file,)
def encode_plus(
self,
notes: Union[np.ndarray, List[pretty_midi.Note]],
truncation_strategy: Optional[TruncationStrategy] = None,
max_length: Optional[int] = None,
**kwargs,
) -> BatchEncoding:
r"""
This is the `encode_plus` method for `Pop2PianoTokenizer`. It converts the midi notes to the transformer
generated token ids. It only works on a single example; to process multiple examples please use the
`batch_encode_plus` or `__call__` method.
Args:
notes (`numpy.ndarray` of shape `[sequence_length, 4]` or `list` of `pretty_midi.Note` objects):
This represents the midi notes. If `notes` is a `numpy.ndarray`:
- Each sequence must have 4 values, they are `onset idx`, `offset idx`, `pitch` and `velocity`.
If `notes` is a `list` containing `pretty_midi.Note` objects:
- Each sequence must have 4 attributes, they are `start`, `end`, `pitch` and `velocity`.
truncation_strategy ([`~tokenization_utils_base.TruncationStrategy`], *optional*):
Indicates the truncation strategy that is going to be used during truncation.
max_length (`int`, *optional*):
Maximum length of the returned list and optionally padding length (see above).
Returns:
`BatchEncoding` containing the tokens ids.
"""
requires_backends(self, ["pretty_midi"])
# check if notes is a pretty_midi object or not, if yes then extract the attributes and put them into a numpy
# array.
if isinstance(notes[0], pretty_midi.Note):
notes = np.array(
[[each_note.start, each_note.end, each_note.pitch, each_note.velocity] for each_note in notes]
).reshape(-1, 4)
# round all the values to the closest integer values.
notes = np.round(notes).astype(np.int32)
max_time_idx = notes[:, :2].max()
times = [[] for i in range((max_time_idx + 1))]
for onset, offset, pitch, velocity in notes:
times[onset].append([pitch, velocity])
times[offset].append([pitch, 0])
tokens = []
current_velocity = 0
for i, time in enumerate(times):
if len(time) == 0:
continue
tokens.append(self._convert_token_to_id(i, "TOKEN_TIME"))
for pitch, velocity in time:
velocity = int(velocity > 0)
if current_velocity != velocity:
current_velocity = velocity
tokens.append(self._convert_token_to_id(velocity, "TOKEN_VELOCITY"))
tokens.append(self._convert_token_to_id(pitch, "TOKEN_NOTE"))
total_len = len(tokens)
# truncation
if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and max_length and total_len > max_length:
tokens, _, _ = self.truncate_sequences(
ids=tokens,
num_tokens_to_remove=total_len - max_length,
truncation_strategy=truncation_strategy,
**kwargs,
)
return BatchEncoding({"token_ids": tokens})
def batch_encode_plus(
self,
notes: Union[np.ndarray, List[pretty_midi.Note]],
truncation_strategy: Optional[TruncationStrategy] = None,
max_length: Optional[int] = None,
**kwargs,
) -> BatchEncoding:
r"""
This is the `batch_encode_plus` method for `Pop2PianoTokenizer`. It converts the midi notes to the transformer
generated token ids. It works on multiple examples by calling `encode_plus` in a loop.
Args:
notes (`numpy.ndarray` of shape `[batch_size, sequence_length, 4]` or `list` of `pretty_midi.Note` objects):
This represents the midi notes. If `notes` is a `numpy.ndarray`:
- Each sequence must have 4 values, they are `onset idx`, `offset idx`, `pitch` and `velocity`.
If `notes` is a `list` containing `pretty_midi.Note` objects:
- Each sequence must have 4 attributes, they are `start`, `end`, `pitch` and `velocity`.
truncation_strategy ([`~tokenization_utils_base.TruncationStrategy`], *optional*):
Indicates the truncation strategy that is going to be used during truncation.
max_length (`int`, *optional*):
Maximum length of the returned list and optionally padding length (see above).
Returns:
`BatchEncoding` containing the tokens ids.
"""
encoded_batch_token_ids = []
for i in range(len(notes)):
encoded_batch_token_ids.append(
self.encode_plus(
notes[i],
truncation_strategy=truncation_strategy,
max_length=max_length,
**kwargs,
)["token_ids"]
)
return BatchEncoding({"token_ids": encoded_batch_token_ids})
def __call__(
self,
notes: Union[
np.ndarray,
List[pretty_midi.Note],
List[List[pretty_midi.Note]],
],
padding: Union[bool, str, PaddingStrategy] = False,
truncation: Union[bool, str, TruncationStrategy] = None,
max_length: Optional[int] = None,
pad_to_multiple_of: Optional[int] = None,
return_attention_mask: Optional[bool] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
verbose: bool = True,
**kwargs,
) -> BatchEncoding:
r"""
This is the `__call__` method for `Pop2PianoTokenizer`. It converts the midi notes to the transformer generated
token ids.
Args:
notes (`numpy.ndarray` of shape `[batch_size, max_sequence_length, 4]` or `list` of `pretty_midi.Note` objects):
This represents the midi notes.
If `notes` is a `numpy.ndarray`:
- Each sequence must have 4 values, they are `onset idx`, `offset idx`, `pitch` and `velocity`.
If `notes` is a `list` containing `pretty_midi.Note` objects:
- Each sequence must have 4 attributes, they are `start`, `end`, `pitch` and `velocity`.
padding (`bool`, `str` or [`~file_utils.PaddingStrategy`], *optional*, defaults to `False`):
Activates and controls padding. Accepts the following values:
- `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
sequence is provided).
- `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
acceptable input length for the model if that argument is not provided.
- `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
lengths).
truncation (`bool`, `str` or [`~tokenization_utils_base.TruncationStrategy`], *optional*, defaults to `False`):
Activates and controls truncation. Accepts the following values:
- `True` or `'longest_first'`: Truncate to a maximum length specified with the argument `max_length` or
to the maximum acceptable input length for the model if that argument is not provided. This will
truncate token by token, removing a token from the longest sequence in the pair if a pair of
sequences (or a batch of pairs) is provided.
- `'only_first'`: Truncate to a maximum length specified with the argument `max_length` or to the
maximum acceptable input length for the model if that argument is not provided. This will only
truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
- `'only_second'`: Truncate to a maximum length specified with the argument `max_length` or to the
maximum acceptable input length for the model if that argument is not provided. This will only
truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
- `False` or `'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence lengths
greater than the model maximum admissible input size).
max_length (`int`, *optional*):
Controls the maximum length to use by one of the truncation/padding parameters. If left unset or set to
`None`, this will use the predefined model maximum length if a maximum length is required by one of the
truncation/padding parameters. If the model has no specific maximum input length (like XLNet)
truncation/padding to a maximum length will be deactivated.
pad_to_multiple_of (`int`, *optional*):
If set will pad the sequence to a multiple of the provided value. This is especially useful to enable
the use of Tensor Cores on NVIDIA hardware with compute capability `>= 7.5` (Volta).
return_attention_mask (`bool`, *optional*):
Whether to return the attention mask. If left to the default, will return the attention mask according
to the specific tokenizer's default, defined by the `return_outputs` attribute.
[What are attention masks?](../glossary#attention-mask)
return_tensors (`str` or [`~file_utils.TensorType`], *optional*):
If set, will return tensors instead of list of python integers. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return Numpy `np.ndarray` objects.
verbose (`bool`, *optional*, defaults to `True`):
Whether or not to print more information and warnings.
Returns:
`BatchEncoding` containing the token_ids.
"""
# check if it is batched or not
# it is batched if it's a list containing lists of `pretty_midi.Note` objects, where the outer list contains all the
# batches and each inner list contains all the Notes of a single batch. Otherwise if an np.ndarray is passed it will be
# considered batched if it has a shape of `[batch_size, sequence_length, 4]` or ndim=3.
is_batched = notes.ndim == 3 if isinstance(notes, np.ndarray) else isinstance(notes[0], list)
# get the truncation and padding strategy
padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies(
padding=padding,
truncation=truncation,
max_length=max_length,
pad_to_multiple_of=pad_to_multiple_of,
verbose=verbose,
**kwargs,
)
if is_batched:
# If the user has not explicitly set `return_attention_mask` to False, we default it to True
return_attention_mask = True if return_attention_mask is None else return_attention_mask
token_ids = self.batch_encode_plus(
notes=notes,
truncation_strategy=truncation_strategy,
max_length=max_length,
**kwargs,
)
else:
token_ids = self.encode_plus(
notes=notes,
truncation_strategy=truncation_strategy,
max_length=max_length,
**kwargs,
)
# since we already have truncated sequences we are just left to do padding
token_ids = self.pad(
token_ids,
padding=padding_strategy,
max_length=max_length,
pad_to_multiple_of=pad_to_multiple_of,
return_attention_mask=return_attention_mask,
return_tensors=return_tensors,
verbose=verbose,
)
return token_ids
def batch_decode(
self,
token_ids,
feature_extractor_output: BatchFeature,
return_midi: bool = True,
):
r"""
This is the `batch_decode` method for `Pop2PianoTokenizer`. It converts the token_ids generated by the
transformer to midi_notes and returns them.
Args:
token_ids (`Union[np.ndarray, torch.Tensor, tf.Tensor]`):
Output token_ids of `Pop2PianoConditionalGeneration` model.
feature_extractor_output (`BatchFeature`):
Denotes the output of `Pop2PianoFeatureExtractor.__call__`. It must contain `"beatstep"` and
`"extrapolated_beatstep"`. Also `"attention_mask_beatsteps"` and
`"attention_mask_extrapolated_beatstep"`
should be present if they were returned by the feature extractor.
return_midi (`bool`, *optional*, defaults to `True`):
Whether to return midi object or not.
Returns:
If `return_midi` is True:
- `BatchEncoding` containing both `notes` and `pretty_midi.pretty_midi.PrettyMIDI` objects.
If `return_midi` is False:
- `BatchEncoding` containing `notes`.
"""
# check whether attention_masks (attention_mask, attention_mask_beatsteps, attention_mask_extrapolated_beatstep) are present or not
attention_masks_present = bool(
hasattr(feature_extractor_output, "attention_mask")
and hasattr(feature_extractor_output, "attention_mask_beatsteps")
and hasattr(feature_extractor_output, "attention_mask_extrapolated_beatstep")
)
# if we are processing batched inputs then we must have attention_masks
if not attention_masks_present and feature_extractor_output["beatsteps"].shape[0] > 1:
raise ValueError(
"attention_mask, attention_mask_beatsteps and attention_mask_extrapolated_beatstep must be present "
"for batched inputs! But one of them were not present."
)
# check for length mismatch between inputs_embeds, beatsteps and extrapolated_beatstep
if attention_masks_present:
# since we know about the number of examples in token_ids from attention_mask
if (
sum(feature_extractor_output["attention_mask"][:, 0] == 0)
!= feature_extractor_output["beatsteps"].shape[0]
or feature_extractor_output["beatsteps"].shape[0]
!= feature_extractor_output["extrapolated_beatstep"].shape[0]
):
raise ValueError(
"Length mistamtch between token_ids, beatsteps and extrapolated_beatstep! Found "
f"token_ids length - {token_ids.shape[0]}, beatsteps shape - {feature_extractor_output['beatsteps'].shape[0]} "
f"and extrapolated_beatsteps shape - {feature_extractor_output['extrapolated_beatstep'].shape[0]}"
)
if feature_extractor_output["attention_mask"].shape[0] != token_ids.shape[0]:
raise ValueError(
f"Found attention_mask of length - {feature_extractor_output['attention_mask'].shape[0]} but token_ids of length - {token_ids.shape[0]}"
)
else:
# if there is no attention mask present then it's surely a single example
if (
feature_extractor_output["beatsteps"].shape[0] != 1
or feature_extractor_output["extrapolated_beatstep"].shape[0] != 1
):
raise ValueError(
"Length mistamtch of beatsteps and extrapolated_beatstep! Since attention_mask is not present the number of examples must be 1, "
f"But found beatsteps length - {feature_extractor_output['beatsteps'].shape[0]}, extrapolated_beatsteps length - {feature_extractor_output['extrapolated_beatstep'].shape[0]}."
)
if attention_masks_present:
# check for zeros (since token_ids are separated by zero arrays)
batch_idx = np.where(feature_extractor_output["attention_mask"][:, 0] == 0)[0]
else:
batch_idx = [token_ids.shape[0]]
notes_list = []
pretty_midi_objects_list = []
start_idx = 0
for index, end_idx in enumerate(batch_idx):
each_tokens_ids = token_ids[start_idx:end_idx]
# check where the whole example ended by searching for eos_token_id and getting the upper bound
each_tokens_ids = each_tokens_ids[:, : np.max(np.where(each_tokens_ids == int(self.eos_token))[1]) + 1]
beatsteps = feature_extractor_output["beatsteps"][index]
extrapolated_beatstep = feature_extractor_output["extrapolated_beatstep"][index]
# if attention mask is present then mask out real array/tensor
if attention_masks_present:
attention_mask_beatsteps = feature_extractor_output["attention_mask_beatsteps"][index]
attention_mask_extrapolated_beatstep = feature_extractor_output[
"attention_mask_extrapolated_beatstep"
][index]
beatsteps = beatsteps[: np.max(np.where(attention_mask_beatsteps == 1)[0]) + 1]
extrapolated_beatstep = extrapolated_beatstep[
: np.max(np.where(attention_mask_extrapolated_beatstep == 1)[0]) + 1
]
each_tokens_ids = to_numpy(each_tokens_ids)
beatsteps = to_numpy(beatsteps)
extrapolated_beatstep = to_numpy(extrapolated_beatstep)
pretty_midi_object = self.relative_batch_tokens_ids_to_midi(
tokens=each_tokens_ids,
beatstep=extrapolated_beatstep,
bars_per_batch=self.num_bars,
cutoff_time_idx=(self.num_bars + 1) * 4,
)
for note in pretty_midi_object.instruments[0].notes:
note.start += beatsteps[0]
note.end += beatsteps[0]
notes_list.append(note)
pretty_midi_objects_list.append(pretty_midi_object)
start_idx += end_idx + 1 # 1 represents the zero array
if return_midi:
return BatchEncoding({"notes": notes_list, "pretty_midi_objects": pretty_midi_objects_list})
return BatchEncoding({"notes": notes_list})
......@@ -32,6 +32,7 @@ is_torch_greater_or_equal_than_2_0 = parsed_torch_version_base >= version.parse(
is_torch_greater_or_equal_than_1_12 = parsed_torch_version_base >= version.parse("1.12")
is_torch_greater_or_equal_than_1_11 = parsed_torch_version_base >= version.parse("1.11")
is_torch_less_than_1_11 = parsed_torch_version_base < version.parse("1.11")
is_torch_1_8_0 = parsed_torch_version_base == version.parse("1.8.0")
def softmax_backward_data(parent, grad_output, output, dim, self):
......
......@@ -57,6 +57,7 @@ from .utils import (
is_cython_available,
is_decord_available,
is_detectron2_available,
is_essentia_available,
is_faiss_available,
is_flax_available,
is_ftfy_available,
......@@ -71,6 +72,7 @@ from .utils import (
is_pandas_available,
is_peft_available,
is_phonemizer_available,
is_pretty_midi_available,
is_pyctcdecode_available,
is_pytesseract_available,
is_pytest_available,
......@@ -825,6 +827,20 @@ def require_librosa(test_case):
return unittest.skipUnless(is_librosa_available(), "test requires librosa")(test_case)
def require_essentia(test_case):
"""
Decorator marking a test that requires essentia
"""
return unittest.skipUnless(is_essentia_available(), "test requires essentia")(test_case)
def require_pretty_midi(test_case):
"""
Decorator marking a test that requires pretty_midi
"""
return unittest.skipUnless(is_pretty_midi_available(), "test requires pretty_midi")(test_case)
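# Illustrative composition of the new decorators on a test method (hypothetical test name, shown only as
# an example of the intended usage):
#
#     @require_essentia
#     @require_pretty_midi
#     def test_pop2piano_midi_roundtrip(self):
#         ...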
def cmd_exists(cmd):
return shutil.which(cmd) is not None
......
......@@ -112,6 +112,7 @@ from .import_utils import (
is_datasets_available,
is_decord_available,
is_detectron2_available,
is_essentia_available,
is_faiss_available,
is_flax_available,
is_ftfy_available,
......@@ -130,6 +131,7 @@ from .import_utils import (
is_pandas_available,
is_peft_available,
is_phonemizer_available,
is_pretty_midi_available,
is_protobuf_available,
is_psutil_available,
is_py3nvml_available,
......
# This file is autogenerated by the command `make fix-copies`, do not edit.
from ..utils import DummyObject, requires_backends
class Pop2PianoFeatureExtractor(metaclass=DummyObject):
_backends = ["essentia", "librosa", "pretty_midi", "scipy", "torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["essentia", "librosa", "pretty_midi", "scipy", "torch"])
class Pop2PianoTokenizer(metaclass=DummyObject):
_backends = ["essentia", "librosa", "pretty_midi", "scipy", "torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["essentia", "librosa", "pretty_midi", "scipy", "torch"])
class Pop2PianoProcessor(metaclass=DummyObject):
_backends = ["essentia", "librosa", "pretty_midi", "scipy", "torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["essentia", "librosa", "pretty_midi", "scipy", "torch"])
# This file is autogenerated by the command `make fix-copies`, do not edit.
from ..utils import DummyObject, requires_backends
class Pop2PianoFeatureExtractor(metaclass=DummyObject):
_backends = ["music"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["music"])
class Pop2PianoTokenizer(metaclass=DummyObject):
_backends = ["music"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["music"])
......@@ -5935,6 +5935,23 @@ class PoolFormerPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["torch"])
POP2PIANO_PRETRAINED_MODEL_ARCHIVE_LIST = None
class Pop2PianoForConditionalGeneration(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class Pop2PianoPreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST = None
......
......@@ -185,6 +185,22 @@ else:
logger.info("Disabling Tensorflow because USE_TORCH is set")
_essentia_available = importlib.util.find_spec("essentia") is not None
try:
_essentia_version = importlib.metadata.version("essentia")
logger.debug(f"Successfully imported essentia version {_essentia_version}")
except importlib.metadata.PackageNotFoundError:
_essentia_version = False
_pretty_midi_available = importlib.util.find_spec("pretty_midi") is not None
try:
_pretty_midi_version = importlib.metadata.version("pretty_midi")
logger.debug(f"Successfully imported pretty_midi version {_pretty_midi_version}")
except importlib.metadata.PackageNotFoundError:
_pretty_midi_available = False
ccl_version = "N/A"
_is_ccl_available = (
importlib.util.find_spec("torch_ccl") is not None
......@@ -242,6 +258,14 @@ def is_librosa_available():
return _librosa_available
def is_essentia_available():
return _essentia_available
def is_pretty_midi_available():
return _pretty_midi_available
def is_torch_cuda_available():
if is_torch_available():
import torch
......@@ -986,6 +1010,27 @@ CCL_IMPORT_ERROR = """
Please note that you may need to restart your runtime after installation.
"""
# docstyle-ignore
ESSENTIA_IMPORT_ERROR = """
{0} requires the essentia library but it was not found in your environment. You can install it with pip:
`pip install essentia==2.1b6.dev1034`
Please note that you may need to restart your runtime after installation.
"""
# docstyle-ignore
LIBROSA_IMPORT_ERROR = """
{0} requires the librosa library but it was not found in your environment. You can install it with pip:
`pip install librosa`
Please note that you may need to restart your runtime after installation.
"""
# docstyle-ignore
PRETTY_MIDI_IMPORT_ERROR = """
{0} requires the pretty_midi library but it was not found in your environment. You can install it with pip:
`pip install pretty_midi`
Please note that you may need to restart your runtime after installation.
"""
DECORD_IMPORT_ERROR = """
{0} requires the decord library but it was not found in your environment. You can install it with pip: `pip install
decord`. Please note that you may need to restart your runtime after installation.
......@@ -1011,11 +1056,14 @@ BACKENDS_MAPPING = OrderedDict(
("bs4", (is_bs4_available, BS4_IMPORT_ERROR)),
("datasets", (is_datasets_available, DATASETS_IMPORT_ERROR)),
("detectron2", (is_detectron2_available, DETECTRON2_IMPORT_ERROR)),
("essentia", (is_essentia_available, ESSENTIA_IMPORT_ERROR)),
("faiss", (is_faiss_available, FAISS_IMPORT_ERROR)),
("flax", (is_flax_available, FLAX_IMPORT_ERROR)),
("ftfy", (is_ftfy_available, FTFY_IMPORT_ERROR)),
("pandas", (is_pandas_available, PANDAS_IMPORT_ERROR)),
("phonemizer", (is_phonemizer_available, PHONEMIZER_IMPORT_ERROR)),
("pretty_midi", (is_pretty_midi_available, PRETTY_MIDI_IMPORT_ERROR)),
("librosa", (is_librosa_available, LIBROSA_IMPORT_ERROR)),
("protobuf", (is_protobuf_available, PROTOBUF_IMPORT_ERROR)),
("pyctcdecode", (is_pyctcdecode_available, PYCTCDECODE_IMPORT_ERROR)),
("pytesseract", (is_pytesseract_available, PYTESSERACT_IMPORT_ERROR)),
......
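# Illustrative note on how the new BACKENDS_MAPPING entries are consumed (relies on the existing
# `requires_backends` helper already used throughout this PR): a class whose backend list includes
# "pretty_midi", e.g.
#
#     requires_backends(self, ["pretty_midi"])
#
# raises an ImportError built from PRETTY_MIDI_IMPORT_ERROR when pretty_midi is not installed, which is
# exactly what the autogenerated dummy objects above do.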
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import tempfile
import unittest
import numpy as np
from datasets import load_dataset
from transformers.testing_utils import (
check_json_file_has_correct_format,
require_essentia,
require_librosa,
require_scipy,
require_tf,
require_torch,
)
from transformers.utils.import_utils import (
is_essentia_available,
is_librosa_available,
is_scipy_available,
is_torch_available,
)
from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin
requirements_available = (
is_torch_available() and is_essentia_available() and is_scipy_available() and is_librosa_available()
)
if requirements_available:
import torch
from transformers import Pop2PianoFeatureExtractor
class Pop2PianoFeatureExtractionTester(unittest.TestCase):
def __init__(
self,
parent,
n_bars=2,
sample_rate=22050,
use_mel=True,
padding_value=0,
vocab_size_special=4,
vocab_size_note=128,
vocab_size_velocity=2,
vocab_size_time=100,
):
self.parent = parent
self.n_bars = n_bars
self.sample_rate = sample_rate
self.use_mel = use_mel
self.padding_value = padding_value
self.vocab_size_special = vocab_size_special
self.vocab_size_note = vocab_size_note
self.vocab_size_velocity = vocab_size_velocity
self.vocab_size_time = vocab_size_time
def prepare_feat_extract_dict(self):
return {
"n_bars": self.n_bars,
"sample_rate": self.sample_rate,
"use_mel": self.use_mel,
"padding_value": self.padding_value,
"vocab_size_special": self.vocab_size_special,
"vocab_size_note": self.vocab_size_note,
"vocab_size_velocity": self.vocab_size_velocity,
"vocab_size_time": self.vocab_size_time,
}
@require_torch
@require_essentia
@require_librosa
@require_scipy
class Pop2PianoFeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase):
feature_extraction_class = Pop2PianoFeatureExtractor if requirements_available else None
def setUp(self):
self.feat_extract_tester = Pop2PianoFeatureExtractionTester(self)
def test_feat_extract_from_and_save_pretrained(self):
feat_extract_first = self.feature_extraction_class(**self.feat_extract_dict)
with tempfile.TemporaryDirectory() as tmpdirname:
saved_file = feat_extract_first.save_pretrained(tmpdirname)[0]
check_json_file_has_correct_format(saved_file)
feat_extract_second = self.feature_extraction_class.from_pretrained(tmpdirname)
dict_first = feat_extract_first.to_dict()
dict_second = feat_extract_second.to_dict()
mel_1 = feat_extract_first.use_mel
mel_2 = feat_extract_second.use_mel
self.assertTrue(np.allclose(mel_1, mel_2))
self.assertEqual(dict_first, dict_second)
def test_feat_extract_to_json_file(self):
feat_extract_first = self.feature_extraction_class(**self.feat_extract_dict)
with tempfile.TemporaryDirectory() as tmpdirname:
json_file_path = os.path.join(tmpdirname, "feat_extract.json")
feat_extract_first.to_json_file(json_file_path)
feat_extract_second = self.feature_extraction_class.from_json_file(json_file_path)
dict_first = feat_extract_first.to_dict()
dict_second = feat_extract_second.to_dict()
mel_1 = feat_extract_first.use_mel
mel_2 = feat_extract_second.use_mel
self.assertTrue(np.allclose(mel_1, mel_2))
self.assertEqual(dict_first, dict_second)
def test_call(self):
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
speech_input = np.zeros([1000000], dtype=np.float32)
input_features = feature_extractor(speech_input, sampling_rate=16_000, return_tensors="np")
self.assertTrue(input_features.input_features.ndim == 3)
self.assertEqual(input_features.input_features.shape[-1], 512)
self.assertTrue(input_features.beatsteps.ndim == 2)
self.assertTrue(input_features.extrapolated_beatstep.ndim == 2)
def test_integration(self):
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
speech_samples = ds.sort("id").select([0])["audio"]
input_speech = [x["array"] for x in speech_samples][0]
sampling_rate = [x["sampling_rate"] for x in speech_samples][0]
feature_extractor = Pop2PianoFeatureExtractor.from_pretrained("sweetcocoa/pop2piano")
input_features = feature_extractor(
input_speech, sampling_rate=sampling_rate, return_tensors="pt"
).input_features
EXPECTED_INPUT_FEATURES = torch.tensor(
[[-7.1493, -6.8701, -4.3214], [-5.9473, -5.7548, -3.8438], [-6.1324, -5.9018, -4.3778]]
)
self.assertTrue(torch.allclose(input_features[0, :3, :3], EXPECTED_INPUT_FEATURES, atol=1e-4))
def test_attention_mask(self):
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
speech_input1 = np.zeros([1_000_000], dtype=np.float32)
speech_input2 = np.random.randint(low=0, high=10, size=500_000).astype(np.float32)
input_features = feature_extractor(
[speech_input1, speech_input2],
sampling_rate=[44_100, 16_000],
return_tensors="np",
return_attention_mask=True,
)
self.assertTrue(hasattr(input_features, "attention_mask"))
# check shapes
self.assertTrue(input_features["attention_mask"].ndim == 2)
self.assertEqual(input_features["attention_mask_beatsteps"].shape[0], 2)
self.assertEqual(input_features["attention_mask_extrapolated_beatstep"].shape[0], 2)
# check that there are no values other than 0 and 1
self.assertTrue(np.max(input_features["attention_mask"]) == 1)
self.assertTrue(np.max(input_features["attention_mask_beatsteps"]) == 1)
self.assertTrue(np.max(input_features["attention_mask_extrapolated_beatstep"]) == 1)
self.assertTrue(np.min(input_features["attention_mask"]) == 0)
self.assertTrue(np.min(input_features["attention_mask_beatsteps"]) == 0)
self.assertTrue(np.min(input_features["attention_mask_extrapolated_beatstep"]) == 0)
def test_batch_feature(self):
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
speech_input1 = np.zeros([1_000_000], dtype=np.float32)
speech_input2 = np.ones([2_000_000], dtype=np.float32)
speech_input3 = np.random.randint(low=0, high=10, size=500_000).astype(np.float32)
input_features = feature_extractor(
[speech_input1, speech_input2, speech_input3],
sampling_rate=[44_100, 16_000, 48_000],
return_attention_mask=True,
)
self.assertEqual(len(input_features["input_features"].shape), 3)
# check shape
self.assertEqual(input_features["beatsteps"].shape[0], 3)
self.assertEqual(input_features["extrapolated_beatstep"].shape[0], 3)
def test_batch_feature_np(self):
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
speech_input1 = np.zeros([1_000_000], dtype=np.float32)
speech_input2 = np.ones([2_000_000], dtype=np.float32)
speech_input3 = np.random.randint(low=0, high=10, size=500_000).astype(np.float32)
input_features = feature_extractor(
[speech_input1, speech_input2, speech_input3],
sampling_rate=[44_100, 16_000, 48_000],
return_tensors="np",
return_attention_mask=True,
)
# check np array or not
self.assertEqual(type(input_features["input_features"]), np.ndarray)
# check shape
self.assertEqual(len(input_features["input_features"].shape), 3)
def test_batch_feature_pt(self):
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
speech_input1 = np.zeros([1_000_000], dtype=np.float32)
speech_input2 = np.ones([2_000_000], dtype=np.float32)
speech_input3 = np.random.randint(low=0, high=10, size=500_000).astype(np.float32)
input_features = feature_extractor(
[speech_input1, speech_input2, speech_input3],
sampling_rate=[44_100, 16_000, 48_000],
return_tensors="pt",
return_attention_mask=True,
)
# check pt tensor or not
self.assertEqual(type(input_features["input_features"]), torch.Tensor)
# check shape
self.assertEqual(len(input_features["input_features"].shape), 3)
@require_tf
def test_batch_feature_tf(self):
import tensorflow as tf
feature_extractor = self.feature_extraction_class(**self.feat_extract_tester.prepare_feat_extract_dict())
speech_input1 = np.zeros([1_000_000], dtype=np.float32)
speech_input2 = np.ones([2_000_000], dtype=np.float32)
speech_input3 = np.random.randint(low=0, high=10, size=500_000).astype(np.float32)
input_features = feature_extractor(
[speech_input1, speech_input2, speech_input3],
sampling_rate=[44_100, 16_000, 48_000],
return_tensors="tf",
return_attention_mask=True,
)
# check tf tensor or not
self.assertTrue(tf.is_tensor(input_features["input_features"]))
# check shape
self.assertEqual(len(input_features["input_features"].shape), 3)
@unittest.skip(
"Pop2PianoFeatureExtractor does not supports padding externally (while processing audios in batches padding is automatically applied to max_length)"
)
def test_padding_accepts_tensors_pt(self):
pass
@unittest.skip(
"Pop2PianoFeatureExtractor does not supports padding externally (while processing audios in batches padding is automatically applied to max_length)"
)
def test_padding_accepts_tensors_tf(self):
pass
@unittest.skip(
"Pop2PianoFeatureExtractor does not supports padding externally (while processing audios in batches padding is automatically applied to max_length)"
)
def test_padding_from_list(self):
pass
@unittest.skip(
"Pop2PianoFeatureExtractor does not supports padding externally (while processing audios in batches padding is automatically applied to max_length)"
)
def test_padding_from_array(self):
pass
@unittest.skip("Pop2PianoFeatureExtractor does not support truncation")
def test_attention_mask_with_truncation(self):
pass
@unittest.skip("Pop2PianoFeatureExtractor does not supports truncation")
def test_truncation_from_array(self):
pass
@unittest.skip("Pop2PianoFeatureExtractor does not supports truncation")
def test_truncation_from_list(self):
pass
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch Pop2Piano model. """
import copy
import tempfile
import unittest
import numpy as np
from datasets import load_dataset
from transformers import Pop2PianoConfig
from transformers.feature_extraction_utils import BatchFeature
from transformers.testing_utils import (
require_essentia,
require_librosa,
require_onnx,
require_scipy,
require_torch,
slow,
torch_device,
)
from transformers.utils import is_essentia_available, is_librosa_available, is_scipy_available, is_torch_available
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor
if is_torch_available():
import torch
from transformers import Pop2PianoForConditionalGeneration
from transformers.models.pop2piano.modeling_pop2piano import POP2PIANO_PRETRAINED_MODEL_ARCHIVE_LIST
from transformers.pytorch_utils import is_torch_1_8_0
else:
is_torch_1_8_0 = False
@require_torch
class Pop2PianoModelTester:
def __init__(
self,
parent,
vocab_size=99,
batch_size=13,
encoder_seq_length=7,
decoder_seq_length=9,
# For common tests
is_training=False,
use_attention_mask=True,
use_labels=True,
hidden_size=64,
num_hidden_layers=5,
num_attention_heads=4,
d_ff=37,
relative_attention_num_buckets=8,
dropout_rate=0.1,
initializer_factor=0.002,
eos_token_id=1,
pad_token_id=0,
decoder_start_token_id=0,
scope=None,
decoder_layers=None,
):
self.parent = parent
self.batch_size = batch_size
self.encoder_seq_length = encoder_seq_length
self.decoder_seq_length = decoder_seq_length
# For common tests
self.seq_length = self.decoder_seq_length
self.is_training = is_training
self.use_attention_mask = use_attention_mask
self.use_labels = use_labels
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.d_ff = d_ff
self.relative_attention_num_buckets = relative_attention_num_buckets
self.dropout_rate = dropout_rate
self.initializer_factor = initializer_factor
self.eos_token_id = eos_token_id
self.pad_token_id = pad_token_id
self.decoder_start_token_id = decoder_start_token_id
self.scope = None
self.decoder_layers = decoder_layers
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size)
decoder_input_ids = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)
attention_mask = None
decoder_attention_mask = None
if self.use_attention_mask:
attention_mask = ids_tensor([self.batch_size, self.encoder_seq_length], vocab_size=2)
decoder_attention_mask = ids_tensor([self.batch_size, self.decoder_seq_length], vocab_size=2)
lm_labels = (
ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size) if self.use_labels else None
)
return self.get_config(), input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels
def get_pipeline_config(self):
return Pop2PianoConfig(
vocab_size=166, # Pop2Piano forces 100 extra tokens
d_model=self.hidden_size,
d_ff=self.d_ff,
d_kv=self.hidden_size // self.num_attention_heads,
num_layers=self.num_hidden_layers,
num_decoder_layers=self.decoder_layers,
num_heads=self.num_attention_heads,
relative_attention_num_buckets=self.relative_attention_num_buckets,
dropout_rate=self.dropout_rate,
initializer_factor=self.initializer_factor,
eos_token_id=self.eos_token_id,
bos_token_id=self.pad_token_id,
pad_token_id=self.pad_token_id,
decoder_start_token_id=self.decoder_start_token_id,
)
def get_config(self):
return Pop2PianoConfig(
vocab_size=self.vocab_size,
d_model=self.hidden_size,
d_ff=self.d_ff,
d_kv=self.hidden_size // self.num_attention_heads,
num_layers=self.num_hidden_layers,
num_decoder_layers=self.decoder_layers,
num_heads=self.num_attention_heads,
relative_attention_num_buckets=self.relative_attention_num_buckets,
dropout_rate=self.dropout_rate,
initializer_factor=self.initializer_factor,
eos_token_id=self.eos_token_id,
bos_token_id=self.pad_token_id,
pad_token_id=self.pad_token_id,
decoder_start_token_id=self.decoder_start_token_id,
)
def check_prepare_lm_labels_via_shift_left(
self,
config,
input_ids,
decoder_input_ids,
attention_mask,
decoder_attention_mask,
lm_labels,
):
model = Pop2PianoForConditionalGeneration(config=config)
model.to(torch_device)
model.eval()
# make sure that lm_labels are correctly padded from the right
lm_labels.masked_fill_((lm_labels == self.decoder_start_token_id), self.eos_token_id)
# add causal pad token mask
triangular_mask = torch.tril(lm_labels.new_ones(lm_labels.shape)).logical_not()
lm_labels.masked_fill_(triangular_mask, self.pad_token_id)
decoder_input_ids = model._shift_right(lm_labels)
for i, (decoder_input_ids_slice, lm_labels_slice) in enumerate(zip(decoder_input_ids, lm_labels)):
# first item
self.parent.assertEqual(decoder_input_ids_slice[0].item(), self.decoder_start_token_id)
if i < decoder_input_ids_slice.shape[-1]:
if i < decoder_input_ids.shape[-1] - 1:
# items before diagonal
self.parent.assertListEqual(
decoder_input_ids_slice[1 : i + 1].tolist(), lm_labels_slice[:i].tolist()
)
# pad items after diagonal
if i < decoder_input_ids.shape[-1] - 2:
self.parent.assertListEqual(
decoder_input_ids_slice[i + 2 :].tolist(), lm_labels_slice[i + 1 : -1].tolist()
)
else:
# all items after square
self.parent.assertListEqual(decoder_input_ids_slice[1:].tolist(), lm_labels_slice[:-1].tolist())
def create_and_check_model(
self,
config,
input_ids,
decoder_input_ids,
attention_mask,
decoder_attention_mask,
lm_labels,
):
model = Pop2PianoForConditionalGeneration(config=config)
model.to(torch_device)
model.eval()
result = model(
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
)
result = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
decoder_past = result.past_key_values
encoder_output = result.encoder_last_hidden_state
self.parent.assertEqual(encoder_output.size(), (self.batch_size, self.encoder_seq_length, self.hidden_size))
# There should be `num_layers` key value embeddings stored in decoder_past
self.parent.assertEqual(len(decoder_past), config.num_layers)
# There should be a self attn key, a self attn value, a cross attn key and a cross attn value stored in each decoder_past tuple
self.parent.assertEqual(len(decoder_past[0]), 4)
def create_and_check_with_lm_head(
self,
config,
input_ids,
decoder_input_ids,
attention_mask,
decoder_attention_mask,
lm_labels,
):
model = Pop2PianoForConditionalGeneration(config=config).to(torch_device).eval()
outputs = model(
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
labels=lm_labels,
)
self.parent.assertEqual(len(outputs), 4)
self.parent.assertEqual(outputs["logits"].size(), (self.batch_size, self.decoder_seq_length, self.vocab_size))
self.parent.assertEqual(outputs["loss"].size(), ())
def create_and_check_decoder_model_past(
self,
config,
input_ids,
decoder_input_ids,
attention_mask,
decoder_attention_mask,
lm_labels,
):
model = Pop2PianoForConditionalGeneration(config=config).get_decoder().to(torch_device).eval()
# first forward pass
outputs = model(input_ids, use_cache=True)
outputs_use_cache_conf = model(input_ids)
outputs_no_past = model(input_ids, use_cache=False)
self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)
output, past_key_values = outputs.to_tuple()
# create a hypothetical next token and extend next_input_ids
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
# append the next token to input_ids
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
output_from_no_past = model(next_input_ids)["last_hidden_state"]
output_from_past = model(next_tokens, past_key_values=past_key_values)["last_hidden_state"]
# select random slice
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx].detach()
output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
# test that outputs are equal for slice
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
def create_and_check_decoder_model_attention_mask_past(
self,
config,
input_ids,
decoder_input_ids,
attention_mask,
decoder_attention_mask,
lm_labels,
):
model = Pop2PianoForConditionalGeneration(config=config).get_decoder()
model.to(torch_device)
model.eval()
# create attention mask
attn_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
half_seq_length = input_ids.shape[-1] // 2
attn_mask[:, half_seq_length:] = 0
# first forward pass
output, past_key_values = model(input_ids, attention_mask=attn_mask, use_cache=True).to_tuple()
# create a hypothetical next token and extend next_input_ids
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
# change a random masked slice from input_ids
random_seq_idx_to_change = ids_tensor((1,), half_seq_length).item() + 1
random_other_next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size).squeeze(-1)
input_ids[:, -random_seq_idx_to_change] = random_other_next_tokens
# append to next input_ids and attn_mask
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
attn_mask = torch.cat(
[attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)],
dim=1,
)
# get two different outputs
output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"]
output_from_past = model(next_tokens, past_key_values=past_key_values, attention_mask=attn_mask)[
"last_hidden_state"
]
# select random slice
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx].detach()
output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
# test that outputs are equal for slice
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
def create_and_check_decoder_model_past_large_inputs(
self,
config,
input_ids,
decoder_input_ids,
attention_mask,
decoder_attention_mask,
lm_labels,
):
model = Pop2PianoForConditionalGeneration(config=config).get_decoder().to(torch_device).eval()
# first forward pass
outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
output, past_key_values = outputs.to_tuple()
# create hypothetical multiple next tokens and extend to next_input_ids
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
next_mask = ids_tensor((self.batch_size, 3), vocab_size=2)
# append to next input_ids and attention mask
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
next_attention_mask = torch.cat([attention_mask, next_mask], dim=-1)
output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)["last_hidden_state"]
output_from_past = model(next_tokens, attention_mask=next_attention_mask, past_key_values=past_key_values)[
"last_hidden_state"
]
# select random slice
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()
self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
# test that outputs are equal for slice
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
def create_and_check_generate_with_past_key_values(
self,
config,
input_ids,
decoder_input_ids,
attention_mask,
decoder_attention_mask,
lm_labels,
):
model = Pop2PianoForConditionalGeneration(config=config).to(torch_device).eval()
torch.manual_seed(0)
output_without_past_cache = model.generate(
input_ids[:1], num_beams=2, max_length=5, do_sample=True, use_cache=False
)
torch.manual_seed(0)
output_with_past_cache = model.generate(input_ids[:1], num_beams=2, max_length=5, do_sample=True)
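# with the same seed, sampling with and without the past-key-values cache should produce identical sequences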
self.parent.assertTrue(torch.all(output_with_past_cache == output_without_past_cache))
def create_and_check_model_fp16_forward(
self,
config,
input_ids,
decoder_input_ids,
attention_mask,
decoder_attention_mask,
lm_labels,
):
model = Pop2PianoForConditionalGeneration(config=config).to(torch_device).half().eval()
output = model(input_ids, decoder_input_ids=input_ids, attention_mask=attention_mask)[
"encoder_last_hidden_state"
]
self.parent.assertFalse(torch.isnan(output).any().item())
def create_and_check_encoder_decoder_shared_weights(
self,
config,
input_ids,
decoder_input_ids,
attention_mask,
decoder_attention_mask,
lm_labels,
):
for model_class in [Pop2PianoForConditionalGeneration]:
torch.manual_seed(0)
model = model_class(config=config).to(torch_device).eval()
# load_state_dict copies the weights but does not tie them
model.encoder.load_state_dict(model.decoder.state_dict(), strict=False)
torch.manual_seed(0)
tied_config = copy.deepcopy(config)
tied_config.tie_encoder_decoder = True
tied_model = model_class(config=tied_config).to(torch_device).eval()
model_result = model(
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
)
tied_model_result = tied_model(
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
)
# check that the tied model has fewer parameters (encoder and decoder share weights)
self.parent.assertLess(
sum(p.numel() for p in tied_model.parameters()), sum(p.numel() for p in model.parameters())
)
random_slice_idx = ids_tensor((1,), model_result[0].shape[-1]).item()
# check that outputs are equal
self.parent.assertTrue(
torch.allclose(
model_result[0][0, :, random_slice_idx], tied_model_result[0][0, :, random_slice_idx], atol=1e-4
)
)
# check that outputs after saving and loading are equal
with tempfile.TemporaryDirectory() as tmpdirname:
tied_model.save_pretrained(tmpdirname)
tied_model = model_class.from_pretrained(tmpdirname)
tied_model.to(torch_device)
tied_model.eval()
# check that the reloaded tied model still has fewer parameters
self.parent.assertLess(
sum(p.numel() for p in tied_model.parameters()), sum(p.numel() for p in model.parameters())
)
random_slice_idx = ids_tensor((1,), model_result[0].shape[-1]).item()
tied_model_result = tied_model(
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
)
# check that outputs are equal
self.parent.assertTrue(
torch.allclose(
model_result[0][0, :, random_slice_idx],
tied_model_result[0][0, :, random_slice_idx],
atol=1e-4,
)
)
def check_resize_embeddings_pop2piano_v1_1(
self,
config,
):
prev_vocab_size = config.vocab_size
config.tie_word_embeddings = False
model = Pop2PianoForConditionalGeneration(config=config).to(torch_device).eval()
model.resize_token_embeddings(prev_vocab_size - 10)
self.parent.assertEqual(model.get_input_embeddings().weight.shape[0], prev_vocab_size - 10)
self.parent.assertEqual(model.get_output_embeddings().weight.shape[0], prev_vocab_size - 10)
self.parent.assertEqual(model.config.vocab_size, prev_vocab_size - 10)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
config,
input_ids,
decoder_input_ids,
attention_mask,
decoder_attention_mask,
lm_labels,
) = config_and_inputs
inputs_dict = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": decoder_attention_mask,
"use_cache": False,
}
return config, inputs_dict
@require_torch
class Pop2PianoModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (Pop2PianoForConditionalGeneration,) if is_torch_available() else ()
all_generative_model_classes = ()
all_parallelizable_model_classes = ()
fx_compatible = False
test_pruning = False
test_resize_embeddings = True
test_model_parallel = False
is_encoder_decoder = True
def setUp(self):
self.model_tester = Pop2PianoModelTester(self)
self.config_tester = ConfigTester(self, config_class=Pop2PianoConfig, d_model=37)
def test_config(self):
self.config_tester.run_common_tests()
def test_shift_right(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_prepare_lm_labels_via_shift_left(*config_and_inputs)
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_model_v1_1(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
# check that gated gelu feed forward and different word embeddings work
config = config_and_inputs[0]
config.tie_word_embeddings = False
config.feed_forward_proj = "gated-gelu"
self.model_tester.create_and_check_model(config, *config_and_inputs[1:])
def test_config_and_model_silu_gated(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
config = config_and_inputs[0]
config.feed_forward_proj = "gated-silu"
self.model_tester.create_and_check_model(*config_and_inputs)
def test_with_lm_head(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_with_lm_head(*config_and_inputs)
def test_decoder_model_past(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_decoder_model_past(*config_and_inputs)
def test_decoder_model_past_with_attn_mask(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_decoder_model_attention_mask_past(*config_and_inputs)
def test_decoder_model_past_with_3d_attn_mask(self):
(
config,
input_ids,
decoder_input_ids,
attention_mask,
decoder_attention_mask,
lm_labels,
) = self.model_tester.prepare_config_and_inputs()
attention_mask = ids_tensor(
[self.model_tester.batch_size, self.model_tester.encoder_seq_length, self.model_tester.encoder_seq_length],
vocab_size=2,
)
decoder_attention_mask = ids_tensor(
[self.model_tester.batch_size, self.model_tester.decoder_seq_length, self.model_tester.decoder_seq_length],
vocab_size=2,
)
self.model_tester.create_and_check_decoder_model_attention_mask_past(
config,
input_ids,
decoder_input_ids,
attention_mask,
decoder_attention_mask,
lm_labels,
)
def test_decoder_model_past_with_large_inputs(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
def test_encoder_decoder_shared_weights(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_encoder_decoder_shared_weights(*config_and_inputs)
@unittest.skipIf(torch_device == "cpu", "Can't do half precision")
def test_model_fp16_forward(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model_fp16_forward(*config_and_inputs)
def test_v1_1_resize_embeddings(self):
config = self.model_tester.prepare_config_and_inputs()[0]
self.model_tester.check_resize_embeddings_pop2piano_v1_1(config)
@slow
def test_model_from_pretrained(self):
for model_name in POP2PIANO_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
model = Pop2PianoForConditionalGeneration.from_pretrained(model_name)
self.assertIsNotNone(model)
@require_onnx
@unittest.skipIf(
is_torch_1_8_0,
reason="Test has a segmentation fault on torch 1.8.0",
)
def test_export_to_onnx(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
model = Pop2PianoForConditionalGeneration(config_and_inputs[0]).to(torch_device)
with tempfile.TemporaryDirectory() as tmpdirname:
torch.onnx.export(
model,
(config_and_inputs[1], config_and_inputs[3], config_and_inputs[2]),
f"{tmpdirname}/Pop2Piano_test.onnx",
export_params=True,
opset_version=9,
input_names=["input_ids", "decoder_input_ids"],
)
def test_pass_with_input_features(self):
input_features = BatchFeature(
{
"input_features": torch.rand((75, 100, 512)).type(torch.float32),
"beatsteps": torch.randint(size=(1, 955), low=0, high=100).type(torch.float32),
"extrapolated_beatstep": torch.randint(size=(1, 900), low=0, high=100).type(torch.float32),
}
)
model = Pop2PianoForConditionalGeneration.from_pretrained("sweetcocoa/pop2piano")
model_opts = model.generate(input_features=input_features["input_features"], return_dict_in_generate=True)
self.assertEqual(model_opts.sequences.ndim, 2)
def test_pass_with_batched_input_features(self):
input_features = BatchFeature(
{
"input_features": torch.rand((220, 70, 512)).type(torch.float32),
"beatsteps": torch.randint(size=(5, 955), low=0, high=100).type(torch.float32),
"extrapolated_beatstep": torch.randint(size=(5, 900), low=0, high=100).type(torch.float32),
"attention_mask": torch.concatenate(
[
torch.ones([120, 70], dtype=torch.int32),
torch.zeros([1, 70], dtype=torch.int32),
torch.ones([50, 70], dtype=torch.int32),
torch.zeros([1, 70], dtype=torch.int32),
torch.ones([47, 70], dtype=torch.int32),
torch.zeros([1, 70], dtype=torch.int32),
],
axis=0,
),
"attention_mask_beatsteps": torch.ones((5, 955)).type(torch.int32),
"attention_mask_extrapolated_beatstep": torch.ones((5, 900)).type(torch.int32),
}
)
model = Pop2PianoForConditionalGeneration.from_pretrained("sweetcocoa/pop2piano")
model_opts = model.generate(
input_features=input_features["input_features"],
attention_mask=input_features["attention_mask"],
return_dict_in_generate=True,
)
self.assertEqual(model_opts.sequences.ndim, 2)
@require_torch
class Pop2PianoModelIntegrationTests(unittest.TestCase):
@slow
def test_mel_conditioner_integration(self):
composer = "composer1"
model = Pop2PianoForConditionalGeneration.from_pretrained("sweetcocoa/pop2piano")
input_embeds = torch.ones([10, 100, 512])
composer_value = model.generation_config.composer_to_feature_token[composer]
composer_value = torch.tensor(composer_value)
composer_value = composer_value.repeat(input_embeds.size(0))
outputs = model.mel_conditioner(
input_embeds, composer_value, min(model.generation_config.composer_to_feature_token.values())
)
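# the mel conditioner adds one composer-embedding timestep per example, so the sequence length should grow from 100 to 101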
# check shape
self.assertEqual(outputs.size(), torch.Size([10, 101, 512]))
# check values
EXPECTED_OUTPUTS = torch.tensor(
[[1.0475305318832397, 0.29052114486694336, -0.47778210043907166], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
)
self.assertTrue(torch.allclose(outputs[0, :3, :3], EXPECTED_OUTPUTS, atol=1e-4))
@slow
@require_essentia
@require_librosa
@require_scipy
def test_full_model_integration(self):
if is_librosa_available() and is_scipy_available() and is_essentia_available() and is_torch_available():
from transformers import Pop2PianoProcessor
speech_input1 = np.zeros([1_000_000], dtype=np.float32)
sampling_rate = 44_100
processor = Pop2PianoProcessor.from_pretrained("sweetcocoa/pop2piano")
input_features = processor.feature_extractor(
speech_input1, sampling_rate=sampling_rate, return_tensors="pt"
)
model = Pop2PianoForConditionalGeneration.from_pretrained("sweetcocoa/pop2piano")
outputs = model.generate(
input_features=input_features["input_features"], return_dict_in_generate=True
).sequences
# check for shapes
self.assertEqual(outputs.size(0), 70)
# check for values
self.assertEqual(outputs[0, :2].detach().cpu().numpy().tolist(), [0, 1])
# This test uses a real piece of music from the K-Pop genre.
@slow
@require_essentia
@require_librosa
@require_scipy
def test_real_music(self):
if is_librosa_available() and is_scipy_available() and is_essentia_available() and is_torch_available():
from transformers import Pop2PianoFeatureExtractor, Pop2PianoTokenizer
model = Pop2PianoForConditionalGeneration.from_pretrained("susnato/pop2piano_dev")
model.eval()
feature_extractor = Pop2PianoFeatureExtractor.from_pretrained("susnato/pop2piano_dev")
tokenizer = Pop2PianoTokenizer.from_pretrained("susnato/pop2piano_dev")
ds = load_dataset("sweetcocoa/pop2piano_ci", split="test")
output_fe = feature_extractor(
ds["audio"][0]["array"], sampling_rate=ds["audio"][0]["sampling_rate"], return_tensors="pt"
)
output_model = model.generate(input_features=output_fe["input_features"], composer="composer1")
output_tokenizer = tokenizer.batch_decode(token_ids=output_model, feature_extractor_output=output_fe)
pretty_midi_object = output_tokenizer["pretty_midi_objects"][0]
# Check that the number of notes is as expected
self.assertEqual(len(pretty_midi_object.instruments[0].notes), 59)
predicted_start_timings, predicted_end_timings = [], []
for note in pretty_midi_object.instruments[0].notes:
predicted_start_timings.append(note.start)
predicted_end_timings.append(note.end)
# Checking note start timings (first 6)
EXPECTED_START_TIMINGS = [
0.4876190423965454,
0.7314285635948181,
0.9752380847930908,
1.4396371841430664,
1.6718367338180542,
1.904036283493042,
]
self.assertTrue(np.allclose(EXPECTED_START_TIMINGS, predicted_start_timings[:6]))
# Checking note end timings (last 6)
EXPECTED_END_TIMINGS = [
12.341403007507324,
12.567797183990479,
12.567797183990479,
12.567797183990479,
12.794191360473633,
12.794191360473633,
]
self.assertTrue(np.allclose(EXPECTED_END_TIMINGS, predicted_end_timings[-6:]))
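# A minimal end-to-end usage sketch of the pipeline exercised by the integration tests above.
# This is a hedged example, not part of the test suite: it assumes the `sweetcocoa/pop2piano`
# checkpoint and hypothetical `audio` / `sampling_rate` variables holding a raw waveform.
#
#     feature_extractor = Pop2PianoFeatureExtractor.from_pretrained("sweetcocoa/pop2piano")
#     tokenizer = Pop2PianoTokenizer.from_pretrained("sweetcocoa/pop2piano")
#     model = Pop2PianoForConditionalGeneration.from_pretrained("sweetcocoa/pop2piano")
#     features = feature_extractor(audio, sampling_rate=sampling_rate, return_tensors="pt")
#     token_ids = model.generate(input_features=features["input_features"], composer="composer1")
#     midi = tokenizer.batch_decode(token_ids=token_ids, feature_extractor_output=features)["pretty_midi_objects"][0]
#     midi.write("output.mid")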
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import shutil
import tempfile
import unittest
import numpy as np
import pytest
from datasets import load_dataset
from transformers.testing_utils import (
require_essentia,
require_librosa,
require_pretty_midi,
require_scipy,
require_torch,
)
from transformers.tokenization_utils import BatchEncoding
from transformers.utils.import_utils import (
is_essentia_available,
is_librosa_available,
is_pretty_midi_available,
is_scipy_available,
is_torch_available,
)
requirements_available = (
is_torch_available()
and is_essentia_available()
and is_scipy_available()
and is_librosa_available()
and is_pretty_midi_available()
)
if requirements_available:
import pretty_midi
from transformers import (
Pop2PianoFeatureExtractor,
Pop2PianoForConditionalGeneration,
Pop2PianoProcessor,
Pop2PianoTokenizer,
)
## TODO: change checkpoints from `susnato/pop2piano_dev` to `sweetcocoa/pop2piano` after the PR is approved
@require_scipy
@require_torch
@require_librosa
@require_essentia
@require_pretty_midi
class Pop2PianoProcessorTest(unittest.TestCase):
def setUp(self):
self.tmpdirname = tempfile.mkdtemp()
feature_extractor = Pop2PianoFeatureExtractor.from_pretrained("susnato/pop2piano_dev")
tokenizer = Pop2PianoTokenizer.from_pretrained("susnato/pop2piano_dev")
processor = Pop2PianoProcessor(feature_extractor, tokenizer)
processor.save_pretrained(self.tmpdirname)
def get_tokenizer(self, **kwargs):
return Pop2PianoTokenizer.from_pretrained(self.tmpdirname, **kwargs)
def get_feature_extractor(self, **kwargs):
return Pop2PianoFeatureExtractor.from_pretrained(self.tmpdirname, **kwargs)
def tearDown(self):
shutil.rmtree(self.tmpdirname)
def test_save_load_pretrained_additional_features(self):
processor = Pop2PianoProcessor(
tokenizer=self.get_tokenizer(),
feature_extractor=self.get_feature_extractor(),
)
processor.save_pretrained(self.tmpdirname)
tokenizer_add_kwargs = self.get_tokenizer(
unk_token="-1",
eos_token="1",
pad_token="0",
bos_token="2",
)
feature_extractor_add_kwargs = self.get_feature_extractor()
processor = Pop2PianoProcessor.from_pretrained(
self.tmpdirname,
unk_token="-1",
eos_token="1",
pad_token="0",
bos_token="2",
)
self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab())
self.assertIsInstance(processor.tokenizer, Pop2PianoTokenizer)
self.assertEqual(processor.feature_extractor.to_json_string(), feature_extractor_add_kwargs.to_json_string())
self.assertIsInstance(processor.feature_extractor, Pop2PianoFeatureExtractor)
def get_inputs(self):
"""get inputs for both feature extractor and tokenizer"""
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
speech_samples = ds.sort("id").select([0])["audio"]
input_speech = [x["array"] for x in speech_samples][0]
sampling_rate = [x["sampling_rate"] for x in speech_samples][0]
feature_extractor_outputs = self.get_feature_extractor()(
audio=input_speech, sampling_rate=sampling_rate, return_tensors="pt"
)
model = Pop2PianoForConditionalGeneration.from_pretrained("susnato/pop2piano_dev")
token_ids = model.generate(input_features=feature_extractor_outputs["input_features"], composer="composer1")
dummy_notes = [
[
pretty_midi.Note(start=0.441179, end=2.159456, pitch=70, velocity=77),
pretty_midi.Note(start=0.673379, end=0.905578, pitch=73, velocity=77),
pretty_midi.Note(start=0.905578, end=2.159456, pitch=73, velocity=77),
pretty_midi.Note(start=1.114558, end=2.159456, pitch=78, velocity=77),
pretty_midi.Note(start=1.323537, end=1.532517, pitch=80, velocity=77),
],
[
pretty_midi.Note(start=0.441179, end=2.159456, pitch=70, velocity=77),
],
]
return input_speech, sampling_rate, token_ids, dummy_notes
def test_feature_extractor(self):
feature_extractor = self.get_feature_extractor()
tokenizer = self.get_tokenizer()
processor = Pop2PianoProcessor(
tokenizer=tokenizer,
feature_extractor=feature_extractor,
)
input_speech, sampling_rate, _, _ = self.get_inputs()
feature_extractor_outputs = feature_extractor(
audio=input_speech, sampling_rate=sampling_rate, return_tensors="np"
)
processor_outputs = processor(audio=input_speech, sampling_rate=sampling_rate, return_tensors="np")
for key in feature_extractor_outputs.keys():
self.assertTrue(np.allclose(feature_extractor_outputs[key], processor_outputs[key], atol=1e-4))
def test_processor_batch_decode(self):
feature_extractor = self.get_feature_extractor()
tokenizer = self.get_tokenizer()
processor = Pop2PianoProcessor(
tokenizer=tokenizer,
feature_extractor=feature_extractor,
)
audio, sampling_rate, token_ids, _ = self.get_inputs()
feature_extractor_output = feature_extractor(audio=audio, sampling_rate=sampling_rate, return_tensors="pt")
encoded_processor = processor.batch_decode(
token_ids=token_ids,
feature_extractor_output=feature_extractor_output,
return_midi=True,
)
encoded_tokenizer = tokenizer.batch_decode(
token_ids=token_ids,
feature_extractor_output=feature_extractor_output,
return_midi=True,
)
# check start timings
encoded_processor_start_timings = [token.start for token in encoded_processor["notes"]]
encoded_tokenizer_start_timings = [token.start for token in encoded_tokenizer["notes"]]
self.assertListEqual(encoded_processor_start_timings, encoded_tokenizer_start_timings)
# check end timings
encoded_processor_end_timings = [token.end for token in encoded_processor["notes"]]
encoded_tokenizer_end_timings = [token.end for token in encoded_tokenizer["notes"]]
self.assertListEqual(encoded_processor_end_timings, encoded_tokenizer_end_timings)
# check pitch
encoded_processor_pitch = [token.pitch for token in encoded_processor["notes"]]
encoded_tokenizer_pitch = [token.pitch for token in encoded_tokenizer["notes"]]
self.assertListEqual(encoded_processor_pitch, encoded_tokenizer_pitch)
# check velocity
encoded_processor_velocity = [token.velocity for token in encoded_processor["notes"]]
encoded_tokenizer_velocity = [token.velocity for token in encoded_tokenizer["notes"]]
self.assertListEqual(encoded_processor_velocity, encoded_tokenizer_velocity)
def test_tokenizer_call(self):
feature_extractor = self.get_feature_extractor()
tokenizer = self.get_tokenizer()
processor = Pop2PianoProcessor(
tokenizer=tokenizer,
feature_extractor=feature_extractor,
)
_, _, _, notes = self.get_inputs()
encoded_processor = processor(
notes=notes,
)
self.assertTrue(isinstance(encoded_processor, BatchEncoding))
def test_processor(self):
feature_extractor = self.get_feature_extractor()
tokenizer = self.get_tokenizer()
processor = Pop2PianoProcessor(
tokenizer=tokenizer,
feature_extractor=feature_extractor,
)
audio, sampling_rate, _, notes = self.get_inputs()
inputs = processor(
audio=audio,
sampling_rate=sampling_rate,
notes=notes,
)
self.assertListEqual(
list(inputs.keys()),
["input_features", "beatsteps", "extrapolated_beatstep", "token_ids"],
)
# test if it raises when no input is passed
with pytest.raises(ValueError):
processor()
def test_model_input_names(self):
feature_extractor = self.get_feature_extractor()
tokenizer = self.get_tokenizer()
processor = Pop2PianoProcessor(
tokenizer=tokenizer,
feature_extractor=feature_extractor,
)
audio, sampling_rate, _, notes = self.get_inputs()
feature_extractor(audio, sampling_rate, return_tensors="pt")
inputs = processor(
audio=audio,
sampling_rate=sampling_rate,
notes=notes,
)
self.assertListEqual(
list(inputs.keys()),
["input_features", "beatsteps", "extrapolated_beatstep", "token_ids"],
)
# coding=utf-8
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Please note that Pop2PianoTokenizer differs too much from our usual tokenizers to reuse the TokenizerTesterMixin class.
"""
import os
import pickle
import shutil
import tempfile
import unittest
from transformers.feature_extraction_utils import BatchFeature
from transformers.testing_utils import (
is_pretty_midi_available,
is_torch_available,
require_pretty_midi,
require_torch,
)
from transformers.tokenization_utils import BatchEncoding
if is_torch_available():
import torch
requirements_available = is_torch_available() and is_pretty_midi_available()
if requirements_available:
import pretty_midi
from transformers import Pop2PianoTokenizer
## TODO: change checkpoints from `susnato/pop2piano_dev` to `sweetcocoa/pop2piano` after the PR is approved
@require_torch
@require_pretty_midi
class Pop2PianoTokenizerTest(unittest.TestCase):
def setUp(self):
super().setUp()
self.tokenizer = Pop2PianoTokenizer.from_pretrained("susnato/pop2piano_dev")
def get_input_notes(self):
notes = [
[
pretty_midi.Note(start=0.441179, end=2.159456, pitch=70, velocity=77),
pretty_midi.Note(start=0.673379, end=0.905578, pitch=73, velocity=77),
pretty_midi.Note(start=0.905578, end=2.159456, pitch=73, velocity=77),
pretty_midi.Note(start=1.114558, end=2.159456, pitch=78, velocity=77),
pretty_midi.Note(start=1.323537, end=1.532517, pitch=80, velocity=77),
],
[
pretty_midi.Note(start=0.441179, end=2.159456, pitch=70, velocity=77),
],
]
return notes
def test_call(self):
notes = self.get_input_notes()
output = self.tokenizer(
notes,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=10,
return_attention_mask=True,
)
# check the output type
self.assertTrue(isinstance(output, BatchEncoding))
# check the values
expected_output_token_ids = torch.tensor(
[[134, 133, 74, 135, 77, 132, 77, 133, 77, 82], [134, 133, 74, 136, 132, 74, 134, 134, 134, 134]]
)
expected_output_attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 0, 0, 0, 0]])
self.assertTrue(torch.allclose(output["token_ids"], expected_output_token_ids, atol=1e-4))
self.assertTrue(torch.allclose(output["attention_mask"], expected_output_attention_mask, atol=1e-4))
def test_batch_decode(self):
# test batch decode with model and feature-extractor outputs (beatsteps, extrapolated_beatstep)
# Please note that this test does not check the accuracy of the outputs; instead, it is designed to make sure that
# the tokenizer's batch_decode can deal with attention_mask in feature-extractor outputs. For the accuracy check,
# please see the `test_batch_decode_outputs` test.
model_output = torch.concatenate(
[
torch.randint(size=[120, 96], low=0, high=70, dtype=torch.long),
torch.zeros(size=[1, 96], dtype=torch.long),
torch.randint(size=[50, 96], low=0, high=40, dtype=torch.long),
torch.zeros(size=[1, 96], dtype=torch.long),
],
axis=0,
)
input_features = BatchFeature(
{
"beatsteps": torch.ones([2, 955]),
"extrapolated_beatstep": torch.ones([2, 1000]),
"attention_mask": torch.concatenate(
[
torch.ones([120, 96], dtype=torch.long),
torch.zeros([1, 96], dtype=torch.long),
torch.ones([50, 96], dtype=torch.long),
torch.zeros([1, 96], dtype=torch.long),
],
axis=0,
),
"attention_mask_beatsteps": torch.ones([2, 955]),
"attention_mask_extrapolated_beatstep": torch.ones([2, 1000]),
}
)
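# note: the all-zero attention_mask rows appear to act as separators between the two examples here (120 + 50 valid chunks), which is consistent with `beatsteps` having a batch size of 2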
output = self.tokenizer.batch_decode(token_ids=model_output, feature_extractor_output=input_features)[
"pretty_midi_objects"
]
# check length
self.assertTrue(len(output) == 2)
# check object type
self.assertTrue(isinstance(output[0], pretty_midi.pretty_midi.PrettyMIDI))
self.assertTrue(isinstance(output[1], pretty_midi.pretty_midi.PrettyMIDI))
def test_batch_decode_outputs(self):
# test batch decode with model and feature-extractor outputs (beatsteps, extrapolated_beatstep)
# Please note that this test checks the accuracy of the outputs of the tokenizer's `batch_decode` method.
model_output = torch.tensor(
[
[134, 133, 74, 135, 77, 82, 84, 136, 132, 74, 77, 82, 84],
[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
]
)
input_features = BatchEncoding(
{
"beatsteps": torch.tensor([[0.0697, 0.1103, 0.1509, 0.1916]]),
"extrapolated_beatstep": torch.tensor([[0.0000, 0.0406, 0.0813, 0.1219]]),
}
)
output = self.tokenizer.batch_decode(token_ids=model_output, feature_extractor_output=input_features)
# check outputs
self.assertEqual(len(output["notes"]), 4)
predicted_start_timings, predicted_end_timings = [], []
for i in output["notes"]:
predicted_start_timings.append(i.start)
predicted_end_timings.append(i.end)
# Checking note start timings
expected_start_timings = torch.tensor(
[
0.069700,
0.110300,
0.110300,
0.110300,
]
)
predicted_start_timings = torch.tensor(predicted_start_timings)
self.assertTrue(torch.allclose(expected_start_timings, predicted_start_timings, atol=1e-4))
# Checking note end timings
expected_end_timings = torch.tensor(
[
0.191600,
0.191600,
0.191600,
0.191600,
]
)
predicted_end_timings = torch.tensor(predicted_end_timings)
self.assertTrue(torch.allclose(expected_end_timings, predicted_end_timings, atol=1e-4))
def test_get_vocab(self):
vocab_dict = self.tokenizer.get_vocab()
self.assertIsInstance(vocab_dict, dict)
self.assertGreaterEqual(len(self.tokenizer), len(vocab_dict))
vocab = [self.tokenizer.convert_ids_to_tokens(i) for i in range(len(self.tokenizer))]
self.assertEqual(len(vocab), len(self.tokenizer))
self.tokenizer.add_tokens(["asdfasdfasdfasdf"])
vocab = [self.tokenizer.convert_ids_to_tokens(i) for i in range(len(self.tokenizer))]
self.assertEqual(len(vocab), len(self.tokenizer))
def test_save_and_load_tokenizer(self):
tmpdirname = tempfile.mkdtemp()
sample_notes = self.get_input_notes()
self.tokenizer.add_tokens(["bim", "bambam"])
additional_special_tokens = self.tokenizer.additional_special_tokens
additional_special_tokens.append("new_additional_special_token")
self.tokenizer.add_special_tokens({"additional_special_tokens": additional_special_tokens})
before_token_ids = self.tokenizer(sample_notes)["token_ids"]
before_vocab = self.tokenizer.get_vocab()
self.tokenizer.save_pretrained(tmpdirname)
after_tokenizer = self.tokenizer.__class__.from_pretrained(tmpdirname)
after_token_ids = after_tokenizer(sample_notes)["token_ids"]
after_vocab = after_tokenizer.get_vocab()
self.assertDictEqual(before_vocab, after_vocab)
self.assertListEqual(before_token_ids, after_token_ids)
self.assertIn("bim", after_vocab)
self.assertIn("bambam", after_vocab)
self.assertIn("new_additional_special_token", after_tokenizer.additional_special_tokens)
shutil.rmtree(tmpdirname)
def test_pickle_tokenizer(self):
tmpdirname = tempfile.mkdtemp()
notes = self.get_input_notes()
subwords = self.tokenizer(notes)["token_ids"]
filename = os.path.join(tmpdirname, "tokenizer.bin")
with open(filename, "wb") as handle:
pickle.dump(self.tokenizer, handle)
with open(filename, "rb") as handle:
tokenizer_new = pickle.load(handle)
subwords_loaded = tokenizer_new(notes)["token_ids"]
self.assertListEqual(subwords, subwords_loaded)
def test_padding_side_in_kwargs(self):
tokenizer_p = Pop2PianoTokenizer.from_pretrained("susnato/pop2piano_dev", padding_side="left")
self.assertEqual(tokenizer_p.padding_side, "left")
tokenizer_p = Pop2PianoTokenizer.from_pretrained("susnato/pop2piano_dev", padding_side="right")
self.assertEqual(tokenizer_p.padding_side, "right")
self.assertRaises(
ValueError,
Pop2PianoTokenizer.from_pretrained,
"susnato/pop2piano_dev",
padding_side="unauthorized",
)
def test_truncation_side_in_kwargs(self):
tokenizer_p = Pop2PianoTokenizer.from_pretrained("susnato/pop2piano_dev", truncation_side="left")
self.assertEqual(tokenizer_p.truncation_side, "left")
tokenizer_p = Pop2PianoTokenizer.from_pretrained("susnato/pop2piano_dev", truncation_side="right")
self.assertEqual(tokenizer_p.truncation_side, "right")
self.assertRaises(
ValueError,
Pop2PianoTokenizer.from_pretrained,
"susnato/pop2piano_dev",
truncation_side="unauthorized",
)
def test_right_and_left_padding(self):
tokenizer = self.tokenizer
notes = self.get_input_notes()
notes = notes[0]
max_length = 20
padding_idx = tokenizer.pad_token_id
# RIGHT PADDING - Check that it correctly pads when a maximum length is specified along with the padding flag set to True
tokenizer.padding_side = "right"
padded_notes = tokenizer(notes, padding="max_length", max_length=max_length)["token_ids"]
padded_notes_length = len(padded_notes)
notes_without_padding = tokenizer(notes, padding="do_not_pad")["token_ids"]
padding_size = max_length - len(notes_without_padding)
self.assertEqual(padded_notes_length, max_length)
self.assertEqual(notes_without_padding + [padding_idx] * padding_size, padded_notes)
# LEFT PADDING - Check that it correctly pads when a maximum length is specified along with the padding flag set to True
tokenizer.padding_side = "left"
padded_notes = tokenizer(notes, padding="max_length", max_length=max_length)["token_ids"]
padded_notes_length = len(padded_notes)
notes_without_padding = tokenizer(notes, padding="do_not_pad")["token_ids"]
padding_size = max_length - len(notes_without_padding)
self.assertEqual(padded_notes_length, max_length)
self.assertEqual([padding_idx] * padding_size + notes_without_padding, padded_notes)
# RIGHT & LEFT PADDING - Check that nothing is done for 'longest' and 'no_padding'
notes_without_padding = tokenizer(notes)["token_ids"]
tokenizer.padding_side = "right"
padded_notes_right = tokenizer(notes, padding=False)["token_ids"]
self.assertEqual(len(padded_notes_right), len(notes_without_padding))
self.assertEqual(padded_notes_right, notes_without_padding)
tokenizer.padding_side = "left"
padded_notes_left = tokenizer(notes, padding="longest")["token_ids"]
self.assertEqual(len(padded_notes_left), len(notes_without_padding))
self.assertEqual(padded_notes_left, notes_without_padding)
tokenizer.padding_side = "right"
padded_notes_right = tokenizer(notes, padding="longest")["token_ids"]
self.assertEqual(len(padded_notes_right), len(notes_without_padding))
self.assertEqual(padded_notes_right, notes_without_padding)
tokenizer.padding_side = "left"
padded_notes_left = tokenizer(notes, padding=False)["token_ids"]
self.assertEqual(len(padded_notes_left), len(notes_without_padding))
self.assertEqual(padded_notes_left, notes_without_padding)
def test_right_and_left_truncation(self):
tokenizer = self.tokenizer
notes = self.get_input_notes()
notes = notes[0]
truncation_size = 3
# RIGHT TRUNCATION - Check that it correctly truncates when a maximum length is specified along with the truncation flag set to True
tokenizer.truncation_side = "right"
full_encoded_notes = tokenizer(notes)["token_ids"]
full_encoded_notes_length = len(full_encoded_notes)
truncated_notes = tokenizer(notes, max_length=full_encoded_notes_length - truncation_size, truncation=True)[
"token_ids"
]
self.assertEqual(full_encoded_notes_length, len(truncated_notes) + truncation_size)
self.assertEqual(full_encoded_notes[:-truncation_size], truncated_notes)
# LEFT TRUNCATION - Check that it correctly truncates when a maximum length is specified along with the truncation flag set to True
tokenizer.truncation_side = "left"
full_encoded_notes = tokenizer(notes)["token_ids"]
full_encoded_notes_length = len(full_encoded_notes)
truncated_notes = tokenizer(notes, max_length=full_encoded_notes_length - truncation_size, truncation=True)[
"token_ids"
]
self.assertEqual(full_encoded_notes_length, len(truncated_notes) + truncation_size)
self.assertEqual(full_encoded_notes[truncation_size:], truncated_notes)
# RIGHT & LEFT TRUNCATION - Check that nothing is done for 'longest' and 'no_truncation'
tokenizer.truncation_side = "right"
truncated_notes_right = tokenizer(notes, truncation=True)["token_ids"]
self.assertEqual(full_encoded_notes_length, len(truncated_notes_right))
self.assertEqual(full_encoded_notes, truncated_notes_right)
tokenizer.truncation_side = "left"
truncated_notes_left = tokenizer(notes, truncation="longest_first")["token_ids"]
self.assertEqual(len(truncated_notes_left), full_encoded_notes_length)
self.assertEqual(truncated_notes_left, full_encoded_notes)
tokenizer.truncation_side = "right"
truncated_notes_right = tokenizer(notes, truncation="longest_first")["token_ids"]
self.assertEqual(len(truncated_notes_right), full_encoded_notes_length)
self.assertEqual(truncated_notes_right, full_encoded_notes)
tokenizer.truncation_side = "left"
truncated_notes_left = tokenizer(notes, truncation=True)["token_ids"]
self.assertEqual(len(truncated_notes_left), full_encoded_notes_length)
self.assertEqual(truncated_notes_left, full_encoded_notes)
def test_padding_to_multiple_of(self):
notes = self.get_input_notes()
if self.tokenizer.pad_token is None:
self.skipTest("No padding token.")
else:
normal_tokens = self.tokenizer(notes[0], padding=True, pad_to_multiple_of=8)
for key, value in normal_tokens.items():
self.assertEqual(len(value) % 8, 0, f"BatchEncoding.{key} is not multiple of 8")
normal_tokens = self.tokenizer(notes[0], pad_to_multiple_of=8)
for key, value in normal_tokens.items():
self.assertNotEqual(len(value) % 8, 0, f"BatchEncoding.{key} is unexpectedly a multiple of 8")
# Should also work with truncation
normal_tokens = self.tokenizer(notes[0], padding=True, truncation=True, pad_to_multiple_of=8)
for key, value in normal_tokens.items():
self.assertEqual(len(value) % 8, 0, f"BatchEncoding.{key} is not multiple of 8")
# truncation to something which is not a multiple of pad_to_multiple_of raises an error
self.assertRaises(
ValueError,
self.tokenizer.__call__,
notes[0],
padding=True,
truncation=True,
max_length=12,
pad_to_multiple_of=8,
)
def test_padding_with_attention_mask(self):
if self.tokenizer.pad_token is None:
self.skipTest("No padding token.")
if "attention_mask" not in self.tokenizer.model_input_names:
self.skipTest("This model does not use attention mask.")
features = [
{"token_ids": [1, 2, 3, 4, 5, 6], "attention_mask": [1, 1, 1, 1, 1, 0]},
{"token_ids": [1, 2, 3], "attention_mask": [1, 1, 0]},
]
padded_features = self.tokenizer.pad(features)
if self.tokenizer.padding_side == "right":
self.assertListEqual(padded_features["attention_mask"], [[1, 1, 1, 1, 1, 0], [1, 1, 0, 0, 0, 0]])
else:
self.assertListEqual(padded_features["attention_mask"], [[1, 1, 1, 1, 1, 0], [0, 0, 0, 1, 1, 0]])
@@ -58,6 +58,8 @@ SPECIAL_CASES_TO_ALLOW = {
# used internally in the configuration class file
"LongT5Config": ["feed_forward_proj"],
# used internally in the configuration class file
"Pop2PianoConfig": ["feed_forward_proj"],
# used internally in the configuration class file
"SwitchTransformersConfig": ["feed_forward_proj"],
# having default values other than `1e-5` - we can't fix them without breaking
"BioGptConfig": ["layer_norm_eps"],
@@ -66,6 +66,7 @@ PRIVATE_MODELS = [
"T5Stack",
"MT5Stack",
"UMT5Stack",
"Pop2PianoStack",
"SwitchTransformersStack",
"TFDPRSpanPredictor",
"MaskFormerSwinModel",
@@ -346,6 +346,8 @@ src/transformers/models/poolformer/configuration_poolformer.py
src/transformers/models/poolformer/feature_extraction_poolformer.py
src/transformers/models/poolformer/image_processing_poolformer.py
src/transformers/models/poolformer/modeling_poolformer.py
src/transformers/models/pop2piano/configuration_pop2piano.py
src/transformers/models/pop2piano/modeling_pop2piano.py
src/transformers/models/prophetnet/tokenization_prophetnet.py
src/transformers/models/rag/tokenization_rag.py
src/transformers/models/realm/configuration_realm.py