# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import collections
import io
import itertools
import json
import logging
import math
import mmap
import multiprocessing as mp
import os
import pickle
import random
import shutil
import struct
import sys
import tarfile
import wave
from collections import defaultdict
from copy import deepcopy
from datetime import datetime
from io import BytesIO
from pathlib import Path
import numpy as np
import pyarrow.parquet as pq
import torch
import torch.nn.functional as F
import torchaudio
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm
from wids import wids
torchaudio.set_audio_backend('soundfile')
AUDIO_FORMAT_SETS = set(['flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'])
try:
    MAIN_SPK_EMBEDDING = torch.load("/workspace/audio_checkpoints/flow_model/spk_embedding/0909/mean_embedding.pt")
    GPT_SPK_EMBEDDING = torch.load("/workspace/audio_checkpoints/flow_model/spk_embedding/0909/spk_mean_embeddings.pt")
except Exception:
    # Fall back to zero embeddings when the checkpoint files are not available.
    MAIN_SPK_EMBEDDING = torch.zeros(1, 192)
    GPT_SPK_EMBEDDING = torch.zeros(1, 192)
def parquet_opener(data, mode='train', tts_data={}):
""" Give url or local file, return file descriptor
Inplace operation.
Args:
data(Iterable[str]): url or local file list
Returns:
Iterable[{src, stream}]
"""
for sample in data:
assert 'src' in sample
url = sample['src']
try:
df = pq.read_table(url).to_pandas()
for i in range(len(df)):
if mode == 'inference' and df.loc[i, 'utt'] not in tts_data:
continue
sample.update(dict(df.loc[i]))
if mode == 'train':
# NOTE do not return sample directly, must initialize a new dict
yield {**sample}
else:
for index, text in enumerate(tts_data[df.loc[i, 'utt']]):
yield {**sample, 'tts_index': index, 'tts_text': text}
except Exception as ex:
logging.warning('Failed to open {}, ex info {}'.format(url, ex))
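# Illustrative usage sketch (not part of the original pipeline): parquet_opener is a plain
# generator stage, so it composes with the other stages in this file by simple chaining.
# The shard path below is hypothetical.
def _sketch_parquet_opener_usage():
    samples = [{'src': '/path/to/shard-000.parquet'}]  # hypothetical shard path
    for sample in parquet_opener(samples, mode='train'):
        # each yielded dict carries the original 'src' plus one parquet row's columns
        print(sorted(sample.keys()))
        break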
def parse_tar_header(header_bytes):
header = struct.unpack("!100s8s8s8s12s12s8s1s100s6s2s32s32s8s8s155s", header_bytes)
return TarHeader(*header)
TarHeader = collections.namedtuple(
"TarHeader",
[
"name",
"mode",
"uid",
"gid",
"size",
"mtime",
"chksum",
"typeflag",
"linkname",
"magic",
"version",
"uname",
"gname",
"devmajor",
"devminor",
"prefix",
],
)
class MMTar:
def __init__(self, file_path: Path | str):
self.stream = open(file_path, "rb")
self.mmap = mmap.mmap(self.stream.fileno(), 0, access=mmap.ACCESS_READ)
def __del__(self):
try:
self.mmap.close()
self.stream.close()
except: # noqa
pass
def get_at_offset(self, offset) -> tuple[str, bytes]:
header = parse_tar_header(self.mmap[offset : offset + 500])
name = header.name.decode("utf-8").strip("\x00")
start = offset + 512
end = start + int(header.size.decode("utf-8")[:-1], 8)
return name, self.mmap[start:end]
class Tar:
def __init__(self, path: Path):
self.tar = MMTar(path)
indices_path = path.with_suffix(".index")
self.index = pickle.loads(indices_path.read_bytes())
self.name_mapping = {}
for name, offset, _ in self.index:
self.name_mapping[name] = offset
def read(self, name: str) -> bytes:
return self.tar.get_at_offset(self.name_mapping[name])[1]
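# The Tar reader above expects a sibling "<name>.index" pickle holding (member_name,
# header_offset, size) triples; only the name and offset are actually consumed by
# Tar.read. A minimal sketch of building such an index (an assumption about how the
# original .index files were produced, not a confirmed tool):
def _sketch_build_tar_index(tar_path):
    entries = []
    with tarfile.open(tar_path) as tf:
        for member in tf:
            # member.offset is the offset of this member's 512-byte header block
            entries.append((member.name, member.offset, member.size))
    index_path = Path(tar_path).with_suffix(".index")
    index_path.write_bytes(pickle.dumps(entries))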
def cosy_jsonl_opener(data, mode='train', tts_data={}):
""" Give url or local file, return file descriptor
Inplace operation.
Args:
data(Iterable[str]): url or local file list
Returns:
Iterable[{src, stream}]
"""
for sample in data:
assert 'src' in sample
cosy_jsonl_path = sample['src']
tar_file_path=cosy_jsonl_path.replace(".vq0907.jsonl",".tar")
try:
tar_data=Tar(Path(tar_file_path))
with open(cosy_jsonl_path, 'r') as f:
for line in f:
item=json.loads(line)
cosy_token = item['cosy_token']
sample['speech_token']=torch.tensor(cosy_token)
sample['speech'], sample['sample_rate']= torchaudio.load(io.BytesIO(tar_data.read(item['filename'])))
# print(item['filename'])
yield {**sample}
except Exception as ex:
logging.warning('Failed to open {}, ex info {}'.format(cosy_jsonl_path, ex))
def cosy_jsonl_opener_vq0918_nopool(data, mode='train', tts_data={}):
""" Give url or local file, return file descriptor
Inplace operation.
Args:
data(Iterable[str]): url or local file list
Returns:
Iterable[{src, stream}]
"""
for sample in data:
assert 'src' in sample
cosy_jsonl_path = sample['src']
tar_file_path=cosy_jsonl_path.replace(".vq0918-nopool.jsonl",".tar")
try:
tar_data=Tar(Path(tar_file_path))
with open(cosy_jsonl_path, 'r') as f:
# cosy_data = [json.loads(line) for line in f]
for line in f:
item=json.loads(line)
cosy_token = item['cosy_token']
sample['speech_token']=torch.tensor(cosy_token)
sample['speech'], sample['sample_rate']= torchaudio.load(io.BytesIO(tar_data.read(item['filename'])))
# print(item['filename'])
yield {**sample}
except Exception as ex:
logging.warning('Failed to open {}, ex info {}'.format(cosy_jsonl_path, ex))
def cosy_jsonl_opener_vq0918_pool2(data, mode='train', tts_data={}):
""" Give url or local file, return file descriptor
Inplace operation.
Args:
data(Iterable[str]): url or local file list
Returns:
Iterable[{src, stream}]
"""
for sample in data:
assert 'src' in sample
cosy_jsonl_path = sample['src']
tar_file_path=cosy_jsonl_path.replace(".vq0918-pool2.jsonl",".tar")
try:
tar_data=Tar(Path(tar_file_path))
with open(cosy_jsonl_path, 'r') as f:
for line in f:
item=json.loads(line)
cosy_token = item['cosy_token']
sample['speech_token']=torch.tensor(cosy_token)
sample['speech'], sample['sample_rate']= torchaudio.load(io.BytesIO(tar_data.read(item['filename'])))
yield {**sample}
except Exception as ex:
logging.warning('Failed to open {}, ex info {}'.format(cosy_jsonl_path, ex))
def cosy_jsonl_opener_vq0918_pool4(data, mode='train', tts_data={}):
""" Give url or local file, return file descriptor
Inplace operation.
Args:
data(Iterable[str]): url or local file list
Returns:
Iterable[{src, stream}]
"""
for sample in data:
assert 'src' in sample
cosy_jsonl_path = sample['src']
tar_file_path=cosy_jsonl_path.replace(".vq0918-pool4.jsonl",".tar")
try:
tar_data=Tar(Path(tar_file_path))
with open(cosy_jsonl_path, 'r') as f:
# cosy_data = [json.loads(line) for line in f]
for line in f:
item=json.loads(line)
cosy_token = item['cosy_token']
sample['speech_token']=torch.tensor(cosy_token)
sample['speech'], sample['sample_rate']= torchaudio.load(io.BytesIO(tar_data.read(item['filename'])))
# print(item['filename'])
yield {**sample}
except Exception as ex:
logging.warning('Failed to open {}, ex info {}'.format(cosy_jsonl_path, ex))
def cosy_jsonl_opener_vq0918_pool8(data, mode='train', tts_data={}):
""" Give url or local file, return file descriptor
Inplace operation.
Args:
data(Iterable[str]): url or local file list
Returns:
Iterable[{src, stream}]
"""
for sample in data:
assert 'src' in sample
cosy_jsonl_path = sample['src']
tar_file_path=cosy_jsonl_path.replace(".vq0918-pool8.jsonl",".tar")
try:
tar_data=Tar(Path(tar_file_path))
with open(cosy_jsonl_path, 'r') as f:
# cosy_data = [json.loads(line) for line in f]
for line in f:
item=json.loads(line)
cosy_token = item['cosy_token']
sample['speech_token']=torch.tensor(cosy_token)
sample['speech'], sample['sample_rate']= torchaudio.load(io.BytesIO(tar_data.read(item['filename'])))
# print(item['filename'])
yield {**sample}
except Exception as ex:
logging.warning('Failed to open {}, ex info {}'.format(cosy_jsonl_path, ex))
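# The five cosy_jsonl_opener* variants above differ only in the jsonl suffix they strip to
# find the companion tar archive. A hedged refactor sketch (not used by the pipeline) that
# covers all of them with one parameterized factory:
def _sketch_make_cosy_jsonl_opener(jsonl_suffix):
    def opener(data, mode='train', tts_data={}):
        for sample in data:
            assert 'src' in sample
            cosy_jsonl_path = sample['src']
            tar_file_path = cosy_jsonl_path.replace(jsonl_suffix, ".tar")
            try:
                tar_data = Tar(Path(tar_file_path))
                with open(cosy_jsonl_path, 'r') as f:
                    for line in f:
                        item = json.loads(line)
                        sample['speech_token'] = torch.tensor(item['cosy_token'])
                        sample['speech'], sample['sample_rate'] = torchaudio.load(
                            io.BytesIO(tar_data.read(item['filename'])))
                        yield {**sample}
            except Exception as ex:
                logging.warning('Failed to open {}, ex info {}'.format(cosy_jsonl_path, ex))
    return opener
# e.g. _sketch_make_cosy_jsonl_opener(".vq0918-pool2.jsonl") behaves like cosy_jsonl_opener_vq0918_pool2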
def process_sft_vq0918_pool4(data, mode='train', tts_data={}):
for sample in data:
assert 'src' in sample
token_npy_path = sample['src']
wav_path=token_npy_path.replace(".vq0918-pool4.npy","")
# wav_path,token_npy_path=sample['src'].split(' ')
try:
sample['speech_token']=torch.tensor(np.load(token_npy_path))
sample['speech'], sample['sample_rate']= torchaudio.load(wav_path)
if sample['speech'].shape[0] > 1:
sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
sample['spk_embedding']=torch.zeros_like(MAIN_SPK_EMBEDDING)
yield {**sample}
except Exception as ex:
logging.warning('Failed to open {}, ex info {}'.format(wav_path, ex))
def process_sft_vq0918_pool4_split(data, mode='train',split_token=25, tts_data={}):
for sample in data:
assert 'src' in sample
token_npy_path = sample['src']
wav_path=token_npy_path.replace(".vq0918-pool4.npy","")
# wav_path,token_npy_path=sample['src'].split(' ')
try:
# sample['speech_token']=torch.tensor(np.load(token_npy_path))
# sample['speech'], sample['sample_rate']= torchaudio.load(wav_path)
# if sample['speech'].shape[0] > 1:
# sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
# sample['spk_embedding']=torch.zeros_like(MAIN_SPK_EMBEDDING)
speech_token=torch.tensor(np.load(token_npy_path))
speech,sample_rate= torchaudio.load(wav_path)
# split_speech=int(split_token / 12.5 * sample_rate)
if speech.shape[0] > 1:
speech = speech.mean(dim=0, keepdim=True)
sample['spk_embedding']=torch.zeros_like(MAIN_SPK_EMBEDDING)
sample['sample_rate']=sample_rate
num_splits = (speech_token.size(0) + split_token - 1) // split_token
for split_id in range(num_splits):
end_token_idx = min((split_id + 1) * split_token, speech_token.size(0))
end_speech_idx=int(np.ceil(end_token_idx / 12.5 * sample_rate))
sample['speech_token']=speech_token[:end_token_idx]
sample['speech']=speech[:,:end_speech_idx]
                logging.debug('%s %s', sample['speech_token'].size(), sample['speech'].size())
yield {**sample}
except Exception as ex:
logging.warning('Failed to open {}, ex info {}'.format(wav_path, ex))
def process_sft_vq0918_pool2(data, mode='train', tts_data={}):
for sample in data:
assert 'src' in sample
token_npy_path = sample['src'].replace(".vq0918-pool4.npy",".vq0918-pool2.npy")
wav_path=token_npy_path.replace(".vq0918-pool2.npy","")
# wav_path,token_npy_path=sample['src'].split(' ')
try:
sample['speech_token']=torch.tensor(np.load(token_npy_path))
sample['speech'], sample['sample_rate']= torchaudio.load(wav_path)
if sample['speech'].shape[0] > 1:
sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
sample['spk_embedding']=torch.zeros_like(MAIN_SPK_EMBEDDING)
yield {**sample}
except Exception as ex:
logging.warning('Failed to open {}, ex info {}'.format(wav_path, ex))
def process_sft_vq0918_pool2_split(data, mode='train',split_token=50, tts_data={}):
for sample in data:
assert 'src' in sample
token_npy_path = sample['src']
wav_path=token_npy_path.replace(".vq0918-pool2.npy","")
# wav_path,token_npy_path=sample['src'].split(' ')
try:
# sample['speech_token']=torch.tensor(np.load(token_npy_path))
# sample['speech'], sample['sample_rate']= torchaudio.load(wav_path)
# if sample['speech'].shape[0] > 1:
# sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
# sample['spk_embedding']=torch.zeros_like(MAIN_SPK_EMBEDDING)
speech_token=torch.tensor(np.load(token_npy_path))
speech,sample_rate= torchaudio.load(wav_path)
# split_speech=int(split_token / 12.5 * sample_rate)
if speech.shape[0] > 1:
speech = speech.mean(dim=0, keepdim=True)
sample['spk_embedding']=torch.zeros_like(MAIN_SPK_EMBEDDING)
sample['sample_rate']=sample_rate
num_splits = (speech_token.size(0) + split_token - 1) // split_token
for split_id in range(num_splits):
end_token_idx = min((split_id + 1) * split_token, speech_token.size(0))
end_speech_idx=int(np.ceil(end_token_idx / 25 * sample_rate))
sample['speech_token']=speech_token[:end_token_idx]
sample['speech']=speech[:,:end_speech_idx]
                logging.debug('%s %s', sample['speech_token'].size(), sample['speech'].size())
yield {**sample}
except Exception as ex:
logging.warning('Failed to open {}, ex info {}'.format(wav_path, ex))
def process_sft_vq0918_pool4_gpt(data, mode='train', tts_data={}):
for sample in data:
assert 'src' in sample
try:
entry=json.loads(sample['src'])
sample['spk_embedding']=torch.zeros_like(MAIN_SPK_EMBEDDING)
for conv in entry["conversations"]:
if "response_wav" in conv:
wav_path = f"/workspace/audio_data/sft/{conv['response_wav']}"
token_npy_path=wav_path.replace(".wav",".wav.vq0918-pool4.npy")
sample['speech_token']=torch.tensor(np.load(token_npy_path))
sample['speech'], sample['sample_rate']= torchaudio.load(wav_path)
if sample['speech'].shape[0] > 1:
sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
                    # NOTE: `spk_embedding` was never defined in the original, so this line always
                    # raised; assuming the module-level GPT speaker embedding was intended.
                    sample['spk_embedding'] = GPT_SPK_EMBEDDING
yield {**sample}
except Exception as ex:
            # `wav_path` may be unbound if the failure happened before it was assigned,
            # so log the source entry instead.
            logging.warning('Failed to process sample {}, ex info {}'.format(sample.get('src', ''), ex))
def process_sft_vq0918_pool4_gpt_1010(data, mode='train', tts_data={}):
for sample in data:
assert 'src' in sample
try:
entry=json.loads(sample['src'])
sample['spk_embedding']=torch.zeros_like(MAIN_SPK_EMBEDDING)
for conv in entry["conversations"]:
if "response_wav" in conv:
wav_path = f"/workspace/audio_data/sft/{conv['response_wav']}"
token_npy_path=wav_path.replace(".wav",".wav.vq0918-pool4.npy")
sample['speech_token']=torch.tensor(np.load(token_npy_path))
sample['speech'], sample['sample_rate']= torchaudio.load(wav_path)
if sample['speech'].shape[0] > 1:
sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
                    # NOTE: `spk_embedding` was never defined in the original, so this line always
                    # raised; assuming the module-level GPT speaker embedding was intended.
                    sample['spk_embedding'] = GPT_SPK_EMBEDDING
yield {**sample}
if "prompt_wav" in conv:
                    # NOTE: the original read conv['response_wav'] here; assuming the prompt wav
                    # was intended in this branch.
                    wav_path = f"/workspace/audio_data/sft/{conv['prompt_wav']}"
token_npy_path=wav_path.replace(".wav",".wav.vq0918-pool4.npy")
sample['speech_token']=torch.tensor(np.load(token_npy_path))
sample['speech'], sample['sample_rate']= torchaudio.load(wav_path)
if sample['speech'].shape[0] > 1:
sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
                    # NOTE: `spk_embedding` was never defined in the original, so this line always
                    # raised; assuming the module-level GPT speaker embedding was intended.
                    sample['spk_embedding'] = GPT_SPK_EMBEDDING
yield {**sample}
except Exception as ex:
            # `wav_path` may be unbound if the failure happened before it was assigned,
            # so log the source entry instead.
            logging.warning('Failed to process sample {}, ex info {}'.format(sample.get('src', ''), ex))
def filter(data,
max_length=10240,
min_length=10,
token_max_length=200,
token_min_length=1,
min_output_input_ratio=0.0005,
max_output_input_ratio=1,
mode='train'):
""" Filter sample according to feature and label length
Inplace operation.
        Args:
            data: Iterable[{key, wav, label, sample_rate}]
            max_length: drop utterance which is greater than max_length(10ms)
            min_length: drop utterance which is less than min_length(10ms)
            token_max_length: drop utterance which is greater than
                token_max_length, especially when using char units for
                English modeling
            token_min_length: drop utterance which is
                less than token_min_length
            min_output_input_ratio: minimal ratio of
                token_length / feats_length(10ms)
            max_output_input_ratio: maximum ratio of
                token_length / feats_length(10ms)
Returns:
Iterable[{key, wav, label, sample_rate}]
"""
for sample in data:
# sample['speech'], sample['sample_rate'] = torchaudio.load(BytesIO(sample['audio_data']))
# del sample['audio_data']
# sample['wav'] is torch.Tensor, we have 100 frames every second
num_frames = sample['speech'].size(1) / sample['sample_rate'] * 100
if num_frames < min_length:
continue
if num_frames > max_length:
continue
if len(sample['text_token']) < token_min_length:
continue
if len(sample['text_token']) > token_max_length:
continue
if len(sample['speech_token']) == 0:
continue
if num_frames != 0:
if len(sample['text_token']) / num_frames < min_output_input_ratio:
continue
if len(sample['text_token']) / num_frames > max_output_input_ratio:
continue
yield sample
def filter_speech_token(data,
max_length=10240,
min_length=10,
token_max_length=5000,
token_min_length=1,
min_output_input_ratio=0.0005,
max_output_input_ratio=30,
mode='train'):
""" Filter sample according to feature and label length
Inplace operation.
        Args:
            data: Iterable[{key, wav, label, sample_rate}]
            max_length: drop utterance which is greater than max_length(10ms)
            min_length: drop utterance which is less than min_length(10ms)
            token_max_length: drop utterance which is greater than
                token_max_length, especially when using char units for
                English modeling
            token_min_length: drop utterance which is
                less than token_min_length
            min_output_input_ratio: minimal ratio of
                token_length / feats_length(10ms)
            max_output_input_ratio: maximum ratio of
                token_length / feats_length(10ms)
Returns:
Iterable[{key, wav, label, sample_rate}]
"""
for sample in data:
# sample['speech'], sample['sample_rate'] = torchaudio.load(BytesIO(sample['audio_data']))
# del sample['audio_data']
# sample['wav'] is torch.Tensor, we have 100 frames every second
num_frames = sample['speech'].size(1) / sample['sample_rate'] * 100
if num_frames < min_length:
continue
if num_frames > max_length:
continue
if len(sample['speech_token']) < token_min_length:
continue
if len(sample['speech_token']) > token_max_length:
continue
if len(sample['speech_token']) == 0:
continue
if num_frames != 0:
if len(sample['speech_token']) / num_frames < min_output_input_ratio:
continue
if len(sample['speech_token']) / num_frames > max_output_input_ratio:
continue
yield sample
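# Both filters above measure utterance length in 10 ms frames: a waveform with `samples`
# samples at rate `sr` has samples / sr * 100 such frames. A toy check of that math:
def _sketch_num_frames():
    speech = torch.zeros(1, 32000)      # 2 s of mono audio at 16 kHz
    sample_rate = 16000
    num_frames = speech.size(1) / sample_rate * 100
    return num_frames                   # -> 200.0 frames (i.e. 200 x 10 ms)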
def resample(data, resample_rate=22050, min_sample_rate=16000, mode='train'):
""" Resample data.
Inplace operation.
Args:
data: Iterable[{key, wav, label, sample_rate}]
resample_rate: target resample rate
Returns:
Iterable[{key, wav, label, sample_rate}]
"""
for sample in data:
assert 'sample_rate' in sample
assert 'speech' in sample
sample_rate = sample['sample_rate']
waveform = sample['speech']
if sample_rate != resample_rate:
if sample_rate < min_sample_rate:
continue
sample['sample_rate'] = resample_rate
sample['speech'] = torchaudio.transforms.Resample(
orig_freq=sample_rate, new_freq=resample_rate)(waveform)
max_val = sample['speech'].abs().max()
if max_val > 1:
sample['speech'] /= max_val
yield sample
def compute_fbank(data,
feat_extractor,
mode='train'):
""" Extract fbank
Args:
data: Iterable[{key, wav, label, sample_rate}]
Returns:
Iterable[{key, feat, label}]
"""
for sample in data:
assert 'sample_rate' in sample
assert 'speech' in sample
# assert 'utt' in sample
# assert 'text_token' in sample
waveform = sample['speech']
mat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
sample['speech_feat'] = mat
del sample['speech']
yield sample
def parse_embedding(data, normalize, mode='train'):
""" Parse utt_embedding/spk_embedding
Args:
data: Iterable[{key, wav, label, sample_rate}]
Returns:
Iterable[{key, feat, label}]
"""
for sample in data:
sample['utt_embedding'] = torch.tensor(sample['utt_embedding'], dtype=torch.float32)
sample['spk_embedding'] = torch.tensor(sample['spk_embedding'], dtype=torch.float32)
if normalize:
sample['utt_embedding'] = F.normalize(sample['utt_embedding'], dim=0)
sample['spk_embedding'] = F.normalize(sample['spk_embedding'], dim=0)
yield sample
def tokenize(data, get_tokenizer, allowed_special, mode='train'):
""" Decode text to chars or BPE
Inplace operation
Args:
data: Iterable[{key, wav, txt, sample_rate}]
Returns:
Iterable[{key, wav, txt, tokens, label, sample_rate}]
"""
tokenizer = get_tokenizer()
for sample in data:
assert 'text' in sample
sample['text_token'] = tokenizer.encode(sample['text'], allowed_special=allowed_special)
if mode == 'inference':
sample['tts_text_token'] = tokenizer.encode(sample['tts_text'], allowed_special=allowed_special)
yield sample
def shuffle(data, shuffle_size=10000, mode='train'):
""" Local shuffle the data
Args:
data: Iterable[{key, feat, label}]
shuffle_size: buffer size for shuffle
Returns:
Iterable[{key, feat, label}]
"""
buf = []
for sample in data:
buf.append(sample)
if len(buf) >= shuffle_size:
random.shuffle(buf)
for x in buf:
yield x
buf = []
# The sample left over
random.shuffle(buf)
for x in buf:
yield x
def sort(data, sort_size=500, mode='train'):
""" Sort the data by feature length.
Sort is used after shuffle and before batch, so we can group
utts with similar lengths into a batch, and `sort_size` should
be less than `shuffle_size`
Args:
data: Iterable[{key, feat, label}]
sort_size: buffer size for sort
Returns:
Iterable[{key, feat, label}]
"""
buf = []
for sample in data:
buf.append(sample)
if len(buf) >= sort_size:
buf.sort(key=lambda x: x['speech_feat'].size(0))
for x in buf:
yield x
buf = []
# The sample left over
buf.sort(key=lambda x: x['speech_feat'].size(0))
for x in buf:
yield x
def static_batch(data, batch_size=16):
""" Static batch the data by `batch_size`
Args:
data: Iterable[{key, feat, label}]
batch_size: batch size
Returns:
Iterable[List[{key, feat, label}]]
"""
buf = []
for sample in data:
buf.append(sample)
if len(buf) >= batch_size:
yield buf
buf = []
if len(buf) > 0:
yield buf
def dynamic_batch(data, max_frames_in_batch=12000, mode='train'):
""" Dynamic batch the data until the total frames in batch
reach `max_frames_in_batch`
Args:
data: Iterable[{key, feat, label}]
max_frames_in_batch: max_frames in one batch
Returns:
Iterable[List[{key, feat, label}]]
"""
buf = []
longest_frames = 0
for sample in data:
assert 'speech_feat' in sample
assert isinstance(sample['speech_feat'], torch.Tensor)
new_sample_frames = sample['speech_feat'].size(0)
longest_frames = max(longest_frames, new_sample_frames)
frames_after_padding = longest_frames * (len(buf) + 1)
if frames_after_padding > max_frames_in_batch:
yield buf
buf = [sample]
longest_frames = new_sample_frames
else:
buf.append(sample)
if len(buf) > 0:
yield buf
def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000, mode='train'):
""" Wrapper for static/dynamic batch
"""
if mode == 'inference':
return static_batch(data, 1)
else:
if batch_type == 'static':
return static_batch(data, batch_size)
elif batch_type == 'dynamic':
return dynamic_batch(data, max_frames_in_batch)
else:
logging.fatal('Unsupported batch type {}'.format(batch_type))
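# Composition sketch (illustrative only): shuffle, sort and batch are generator stages and
# are meant to be chained; sort groups similarly sized utterances so that static/dynamic
# batching pads less. The toy samples below carry only the 'speech_feat' field these
# stages actually touch.
def _sketch_batching_pipeline():
    data = [{'speech_feat': torch.zeros(n, 80)} for n in (30, 10, 20, 40)]
    stream = shuffle(iter(data), shuffle_size=4)
    stream = sort(stream, sort_size=4)
    for minibatch in batch(stream, batch_type='static', batch_size=2):
        print([s['speech_feat'].size(0) for s in minibatch])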
def padding(data, use_spk_embedding, mode='train'):
""" Padding the data into training data
Args:
data: Iterable[List[{key, feat, label}]]
Returns:
Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)]
"""
for sample in data:
assert isinstance(sample, list)
speech_feat_len = torch.tensor([x['speech_feat'].size(1) for x in sample],
dtype=torch.int32)
order = torch.argsort(speech_feat_len, descending=True)
utts = [sample[i]['utt'] for i in order]
speech_token = [torch.tensor(sample[i]['speech_token']) for i in order]
speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
speech_token = pad_sequence(speech_token,
batch_first=True,
padding_value=0)
speech_feat = [sample[i]['speech_feat'] for i in order]
speech_feat_len = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32)
speech_feat = pad_sequence(speech_feat,
batch_first=True,
padding_value=0)
text = [sample[i]['text'] for i in order]
text_token = [torch.tensor(sample[i]['text_token']) for i in order]
text_token_len = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32)
text_token = pad_sequence(text_token, batch_first=True, padding_value=0)
utt_embedding = torch.stack([sample[i]['utt_embedding'] for i in order], dim=0)
spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
batch = {
"utts": utts,
"speech_token": speech_token,
"speech_token_len": speech_token_len,
"speech_feat": speech_feat,
"speech_feat_len": speech_feat_len,
"text": text,
"text_token": text_token,
"text_token_len": text_token_len,
"utt_embedding": utt_embedding,
"spk_embedding": spk_embedding,
}
if mode == 'inference':
tts_text = [sample[i]['tts_text'] for i in order]
tts_index = [sample[i]['tts_index'] for i in order]
tts_text_token = [torch.tensor(sample[i]['tts_text_token']) for i in order]
tts_text_token_len = torch.tensor([i.size(0) for i in tts_text_token], dtype=torch.int32)
tts_text_token = pad_sequence(tts_text_token, batch_first=True, padding_value=-1)
batch.update({'tts_text': tts_text,
'tts_index': tts_index,
'tts_text_token': tts_text_token,
'tts_text_token_len': tts_text_token_len})
if use_spk_embedding is True:
batch["embedding"] = batch["spk_embedding"]
else:
batch["embedding"] = batch["utt_embedding"]
yield batch
def padding_speech_token(data, use_spk_embedding, mode='train'):
""" Padding the data into training data
Args:
data: Iterable[List[{key, feat, label}]]
Returns:
Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)]
"""
for sample in data:
assert isinstance(sample, list)
speech_feat_len = torch.tensor([x['speech_feat'].size(1) for x in sample],
dtype=torch.int32)
order = torch.argsort(speech_feat_len, descending=True)
# utts = [sample[i]['utt'] for i in order]
# speech_token = [torch.tensor(sample[i]['speech_token']) for i in order]
try:
speech_token = [sample[i]['speech_token'].clone().detach() for i in order]
speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
speech_token = pad_sequence(speech_token,
batch_first=True,
padding_value=0)
speech_feat = [sample[i]['speech_feat'] for i in order]
speech_feat_len = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32)
speech_feat = pad_sequence(speech_feat,
batch_first=True,
padding_value=0)
batch = {
"speech_token": speech_token,
"speech_token_len": speech_token_len,
"speech_feat": speech_feat,
"speech_feat_len": speech_feat_len,
}
if mode == 'inference':
tts_text = [sample[i]['tts_text'] for i in order]
tts_index = [sample[i]['tts_index'] for i in order]
tts_text_token = [torch.tensor(sample[i]['tts_text_token']) for i in order]
tts_text_token_len = torch.tensor([i.size(0) for i in tts_text_token], dtype=torch.int32)
tts_text_token = pad_sequence(tts_text_token, batch_first=True, padding_value=-1)
batch.update({'tts_text': tts_text,
'tts_index': tts_index,
'tts_text_token': tts_text_token,
'tts_text_token_len': tts_text_token_len})
# if use_spk_embedding is True:
# batch["embedding"] = batch["spk_embedding"]
# else:
# batch["embedding"] = batch["utt_embedding"]
batch["embedding"]=torch.zeros((batch["speech_feat"].size(0),192),device=batch["speech_feat"].device)
yield batch
except Exception as ex:
logging.warning(' ex info {}'.format(ex))
# assert False
def padding_speech_token_spk(data, use_spk_embedding, mode='train'):
""" Padding the data into training data
Args:
data: Iterable[List[{key, feat, label}]]
Returns:
Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)]
"""
for sample in data:
assert isinstance(sample, list)
speech_feat_len = torch.tensor([x['speech_feat'].size(1) for x in sample],
dtype=torch.int32)
order = torch.argsort(speech_feat_len, descending=True)
# utts = [sample[i]['utt'] for i in order]
# speech_token = [torch.tensor(sample[i]['speech_token']) for i in order]
try:
speech_token = [sample[i]['speech_token'].clone().detach() for i in order]
speech_token_len = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
speech_token = pad_sequence(speech_token,
batch_first=True,
padding_value=0)
speech_feat = [sample[i]['speech_feat'] for i in order]
speech_feat_len = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32)
speech_feat = pad_sequence(speech_feat,
batch_first=True,
padding_value=0)
spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
batch = {
"speech_token": speech_token,
"speech_token_len": speech_token_len,
"speech_feat": speech_feat,
"speech_feat_len": speech_feat_len,
"spk_embedding": spk_embedding,
}
if mode == 'inference':
tts_text = [sample[i]['tts_text'] for i in order]
tts_index = [sample[i]['tts_index'] for i in order]
tts_text_token = [torch.tensor(sample[i]['tts_text_token']) for i in order]
tts_text_token_len = torch.tensor([i.size(0) for i in tts_text_token], dtype=torch.int32)
tts_text_token = pad_sequence(tts_text_token, batch_first=True, padding_value=-1)
batch.update({'tts_text': tts_text,
'tts_index': tts_index,
'tts_text_token': tts_text_token,
'tts_text_token_len': tts_text_token_len})
# if use_spk_embedding is True:
# batch["embedding"] = batch["spk_embedding"]
# else:
# batch["embedding"] = batch["utt_embedding"]
# batch["embedding"]=torch.zeros((batch["speech_feat"].size(0),192),device=batch["speech_feat"].device)
batch["embedding"] = batch["spk_embedding"]
yield batch
except Exception as ex:
logging.warning(' ex info {}'.format(ex))
# assert False
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from einops import pack, rearrange, repeat
from matcha.models.components.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D
from matcha.models.components.transformer import BasicTransformerBlock
class ConditionalDecoder(nn.Module):
def __init__(
self,
in_channels,
out_channels,
channels=(256, 256),
dropout=0.05,
attention_head_dim=64,
n_blocks=1,
num_mid_blocks=2,
num_heads=4,
act_fn="snake",
):
"""
This decoder requires an input with the same shape of the target. So, if your text content
is shorter or longer than the outputs, please re-sampling it before feeding to the decoder.
"""
super().__init__()
channels = tuple(channels)
self.in_channels = in_channels
self.out_channels = out_channels
self.time_embeddings = SinusoidalPosEmb(in_channels)
time_embed_dim = channels[0] * 4
self.time_mlp = TimestepEmbedding(
in_channels=in_channels,
time_embed_dim=time_embed_dim,
act_fn="silu",
)
self.down_blocks = nn.ModuleList([])
self.mid_blocks = nn.ModuleList([])
self.up_blocks = nn.ModuleList([])
output_channel = in_channels
for i in range(len(channels)): # pylint: disable=consider-using-enumerate
input_channel = output_channel
output_channel = channels[i]
is_last = i == len(channels) - 1
resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
dim=output_channel,
num_attention_heads=num_heads,
attention_head_dim=attention_head_dim,
dropout=dropout,
activation_fn=act_fn,
)
for _ in range(n_blocks)
]
)
downsample = (
Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
)
self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
for i in range(num_mid_blocks):
input_channel = channels[-1]
out_channels = channels[-1]
resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
dim=output_channel,
num_attention_heads=num_heads,
attention_head_dim=attention_head_dim,
dropout=dropout,
activation_fn=act_fn,
)
for _ in range(n_blocks)
]
)
self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
channels = channels[::-1] + (channels[0],)
for i in range(len(channels) - 1):
input_channel = channels[i] * 2
output_channel = channels[i + 1]
is_last = i == len(channels) - 2
resnet = ResnetBlock1D(
dim=input_channel,
dim_out=output_channel,
time_emb_dim=time_embed_dim,
)
transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
dim=output_channel,
num_attention_heads=num_heads,
attention_head_dim=attention_head_dim,
dropout=dropout,
activation_fn=act_fn,
)
for _ in range(n_blocks)
]
)
upsample = (
Upsample1D(output_channel, use_conv_transpose=True)
if not is_last
else nn.Conv1d(output_channel, output_channel, 3, padding=1)
)
self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
self.final_block = Block1D(channels[-1], channels[-1])
self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
self.initialize_weights()
def initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv1d):
nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.GroupNorm):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, x, mask, mu, t, spks=None, cond=None):
"""Forward pass of the UNet1DConditional model.
Args:
x (torch.Tensor): shape (batch_size, in_channels, time)
mask (_type_): shape (batch_size, 1, time)
t (_type_): shape (batch_size)
spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
cond (_type_, optional): placeholder for future use. Defaults to None.
        Returns:
            torch.Tensor: output, shape (batch_size, out_channels, time)
"""
t = self.time_embeddings(t)
t = self.time_mlp(t)
x = pack([x, mu], "b * t")[0]
if spks is not None:
spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
x = pack([x, spks], "b * t")[0]
if cond is not None:
x = pack([x, cond], "b * t")[0]
hiddens = []
masks = [mask]
for resnet, transformer_blocks, downsample in self.down_blocks:
mask_down = masks[-1]
x = resnet(x, mask_down, t)
x = rearrange(x, "b c t -> b t c").contiguous()
attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
for transformer_block in transformer_blocks:
x = transformer_block(
hidden_states=x,
attention_mask=attn_mask,
timestep=t,
)
x = rearrange(x, "b t c -> b c t").contiguous()
hiddens.append(x) # Save hidden states for skip connections
x = downsample(x * mask_down)
masks.append(mask_down[:, :, ::2])
masks = masks[:-1]
mask_mid = masks[-1]
for resnet, transformer_blocks in self.mid_blocks:
x = resnet(x, mask_mid, t)
x = rearrange(x, "b c t -> b t c").contiguous()
attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
for transformer_block in transformer_blocks:
x = transformer_block(
hidden_states=x,
attention_mask=attn_mask,
timestep=t,
)
x = rearrange(x, "b t c -> b c t").contiguous()
for resnet, transformer_blocks, upsample in self.up_blocks:
mask_up = masks.pop()
skip = hiddens.pop()
x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
x = resnet(x, mask_up, t)
x = rearrange(x, "b c t -> b t c").contiguous()
attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
for transformer_block in transformer_blocks:
x = transformer_block(
hidden_states=x,
attention_mask=attn_mask,
timestep=t,
)
x = rearrange(x, "b t c -> b c t").contiguous()
x = upsample(x * mask_up)
x = self.final_block(x, mask_up)
output = self.final_proj(x * mask_up)
return output * mask
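# Minimal shape sketch for the decoder above (assumes matcha-tts is installed; the
# hyper-parameters here are illustrative, not a training configuration). The forward pass
# concatenates x and mu (and optionally spks/cond) along channels, so in_channels must
# equal the sum of those channel counts: with x and mu at 80 mel bins each and no
# spks/cond, in_channels = 160.
def _sketch_conditional_decoder_shapes():
    decoder = ConditionalDecoder(in_channels=160, out_channels=80, channels=(64, 64),
                                 attention_head_dim=32, n_blocks=1, num_mid_blocks=1,
                                 num_heads=2, act_fn="gelu")
    x = torch.randn(2, 80, 64)          # noisy mel, (batch, n_feats, time)
    mu = torch.randn(2, 80, 64)         # encoder output, same shape as the target
    mask = torch.ones(2, 1, 64)
    t = torch.rand(2)                   # one timestep per batch element
    out = decoder(x, mask, mu, t)       # -> (2, 80, 64)
    return out.shape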
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import random
from typing import Dict, Optional
import torch
import torch.nn as nn
from torch.nn import functional as F
from omegaconf import DictConfig
from cosyvoice.utils.mask import make_pad_mask
class MaskedDiffWithXvec(torch.nn.Module):
def __init__(self,
input_size: int = 512,
output_size: int = 80,
spk_embed_dim: int = 192,
output_type: str = "mel",
vocab_size: int = 4096,
input_frame_rate: int = 50,
only_mask_loss: bool = True,
encoder: torch.nn.Module = None,
length_regulator: torch.nn.Module = None,
decoder: torch.nn.Module = None,
decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1, 'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine', 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}), 'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64, 'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050, 'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
super().__init__()
self.input_size = input_size
self.output_size = output_size
self.decoder_conf = decoder_conf
self.mel_feat_conf = mel_feat_conf
self.vocab_size = vocab_size
self.output_type = output_type
self.input_frame_rate = input_frame_rate
logging.info(f"input frame rate={self.input_frame_rate}")
self.input_embedding = nn.Embedding(vocab_size, input_size)
self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
self.encoder = encoder
self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
self.decoder = decoder
self.length_regulator = length_regulator
self.only_mask_loss = only_mask_loss
def forward(
self,
batch: dict,
device: torch.device,
) -> Dict[str, Optional[torch.Tensor]]:
token = batch['speech_token'].to(device)
token_len = batch['speech_token_len'].to(device)
feat = batch['speech_feat'].to(device)
feat_len = batch['speech_feat_len'].to(device)
embedding = batch['embedding'].to(device)
# xvec projection
embedding = F.normalize(embedding, dim=1)
embedding = self.spk_embed_affine_layer(embedding)
# embedding=None
# concat text and prompt_text
mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
# print(token.max(),self.input_embedding)
token = self.input_embedding(torch.clamp(token, min=0)) * mask
# text encode
h, h_lengths = self.encoder(token, token_len)
h = self.encoder_proj(h)
h, h_lengths = self.length_regulator(h, feat_len)
# get conditions
conds = torch.zeros(feat.shape, device=token.device)
for i, j in enumerate(feat_len):
if random.random() < 0.5:
continue
index = random.randint(0, int(0.8 * j))
conds[i, :index] = feat[i, :index]
conds = conds.transpose(1, 2)
mask = (~make_pad_mask(feat_len)).to(h)
feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
loss, _ = self.decoder.compute_loss(
feat.transpose(1, 2).contiguous(),
mask.unsqueeze(1),
h.transpose(1, 2).contiguous(),
embedding,
cond=conds
)
return {'loss': loss}
@torch.inference_mode()
def inference(self,
token,
token_len,
prompt_token,
prompt_token_len,
prompt_feat,
prompt_feat_len,
embedding,
option_steps=10):
assert token.shape[0] == 1
# xvec projection
embedding = F.normalize(embedding, dim=1)
embedding = self.spk_embed_affine_layer(embedding)
# concat text and prompt_text
token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(embedding)
token = self.input_embedding(torch.clamp(token, min=0)) * mask
# text encode
h, h_lengths = self.encoder(token, token_len)
h = self.encoder_proj(h)
feat_len = (token_len / self.input_frame_rate * 22050 / 256).int()
h, h_lengths = self.length_regulator(h, feat_len)
# get conditions
conds = torch.zeros([1, feat_len.max().item(), self.output_size], device=token.device)
if prompt_feat.shape[1] != 0:
for i, j in enumerate(prompt_feat_len):
conds[i, :j] = prompt_feat[i]
conds = conds.transpose(1, 2)
mask = (~make_pad_mask(feat_len)).to(h)
feat = self.decoder(
mu=h.transpose(1, 2).contiguous(),
mask=mask.unsqueeze(1),
spks=embedding,
cond=conds,
n_timesteps=option_steps
)
if prompt_feat.shape[1] != 0:
feat = feat[:, :, prompt_feat.shape[1]:]
return feat
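# In inference() above the mel length is derived from the token length via
# feat_len = token_len / input_frame_rate * 22050 / 256, i.e. 50 Hz tokens are mapped to
# 22.05 kHz mel frames with a hop of 256 samples (~86 frames per second). A worked example:
def _sketch_token_to_mel_length(token_len=100, input_frame_rate=50):
    hop_size, sampling_rate = 256, 22050            # matches mel_feat_conf defaults above
    feat_len = int(token_len / input_frame_rate * sampling_rate / hop_size)
    return feat_len                                 # 100 tokens (2 s) -> 172 mel frames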
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import random
from typing import Dict, Optional
import torch
import torch.nn as nn
from torch.nn import functional as F
from omegaconf import DictConfig
from cosyvoice.utils.mask import make_pad_mask
class MaskedDiffWithXvec(torch.nn.Module):
def __init__(self,
input_size: int = 512,
output_size: int = 80,
spk_embed_dim: int = 192,
output_type: str = "mel",
vocab_size: int = 4096,
input_frame_rate: int = 50,
only_mask_loss: bool = True,
encoder: torch.nn.Module = None,
length_regulator: torch.nn.Module = None,
decoder: torch.nn.Module = None,
decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1, 'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine', 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}), 'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64, 'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050, 'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
super().__init__()
self.input_size = input_size
self.output_size = output_size
self.decoder_conf = decoder_conf
self.mel_feat_conf = mel_feat_conf
self.vocab_size = vocab_size
self.output_type = output_type
self.input_frame_rate = input_frame_rate
logging.info(f"input frame rate={self.input_frame_rate}")
self.input_embedding = nn.Embedding(vocab_size, input_size)
self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
self.encoder = encoder
self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
self.decoder = decoder
self.length_regulator = length_regulator
self.only_mask_loss = only_mask_loss
def forward(
self,
batch: dict,
device: torch.device,
) -> Dict[str, Optional[torch.Tensor]]:
token = batch['speech_token'].to(device)
token_len = batch['speech_token_len'].to(device)
feat = batch['speech_feat'].to(device)
feat_len = batch['speech_feat_len'].to(device)
embedding = batch['embedding'].to(device)
# xvec projection
embedding = F.normalize(embedding, dim=1)
embedding = self.spk_embed_affine_layer(embedding)
# embedding=None
# concat text and prompt_text
mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
token = self.input_embedding(torch.clamp(token, min=0)) * mask
# text encode
h, h_lengths = self.encoder(token, token_len)
h = self.encoder_proj(h)
h, h_lengths = self.length_regulator(h, feat_len)
# get conditions
conds = torch.zeros(feat.shape, device=token.device)
# for i, j in enumerate(feat_len):
# if random.random() < 0.5:
# continue
# index = random.randint(0, int(0.3 * j))
# conds[i, :index] = feat[i, :index]
conds = conds.transpose(1, 2)
mask = (~make_pad_mask(feat_len)).to(h)
feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
loss, _ = self.decoder.compute_loss(
feat.transpose(1, 2).contiguous(),
mask.unsqueeze(1),
h.transpose(1, 2).contiguous(),
embedding,
cond=conds
)
return {'loss': loss}
@torch.inference_mode()
def inference(self,
token,
token_len,
prompt_token,
prompt_token_len,
prompt_feat,
prompt_feat_len,
embedding):
assert token.shape[0] == 1
# xvec projection
embedding = F.normalize(embedding, dim=1)
embedding = self.spk_embed_affine_layer(embedding)
# concat text and prompt_text
token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(embedding)
token = self.input_embedding(torch.clamp(token, min=0)) * mask
# text encode
h, h_lengths = self.encoder(token, token_len)
h = self.encoder_proj(h)
feat_len = (token_len / self.input_frame_rate * 22050 / 256).int()
h, h_lengths = self.length_regulator(h, feat_len)
# get conditions
conds = torch.zeros([1, feat_len.max().item(), self.output_size], device=token.device)
if prompt_feat.shape[1] != 0:
for i, j in enumerate(prompt_feat_len):
conds[i, :j] = prompt_feat[i]
conds = conds.transpose(1, 2)
mask = (~make_pad_mask(feat_len)).to(h)
feat = self.decoder(
mu=h.transpose(1, 2).contiguous(),
mask=mask.unsqueeze(1),
spks=embedding,
cond=conds,
n_timesteps=10
)
if prompt_feat.shape[1] != 0:
feat = feat[:, :, prompt_feat.shape[1]:]
return feat
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.functional as F
from matcha.models.components.flow_matching import BASECFM
class ConditionalCFM(BASECFM):
def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
super().__init__(
n_feats=in_channels,
cfm_params=cfm_params,
n_spks=n_spks,
spk_emb_dim=spk_emb_dim,
)
self.t_scheduler = cfm_params.t_scheduler
self.training_cfg_rate = cfm_params.training_cfg_rate
self.inference_cfg_rate = cfm_params.inference_cfg_rate
in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
# Just change the architecture of the estimator here
self.estimator = estimator
@torch.inference_mode()
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
"""Forward diffusion
Args:
mu (torch.Tensor): output of encoder
shape: (batch_size, n_feats, mel_timesteps)
mask (torch.Tensor): output_mask
shape: (batch_size, 1, mel_timesteps)
n_timesteps (int): number of diffusion steps
temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
spks (torch.Tensor, optional): speaker ids. Defaults to None.
shape: (batch_size, spk_emb_dim)
cond: Not used but kept for future purposes
Returns:
sample: generated mel-spectrogram
shape: (batch_size, n_feats, mel_timesteps)
"""
torch.manual_seed(42)
z = torch.randn_like(mu) * temperature
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
if self.t_scheduler == 'cosine':
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)
def solve_euler(self, x, t_span, mu, mask, spks, cond):
"""
Fixed euler solver for ODEs.
Args:
x (torch.Tensor): random noise
t_span (torch.Tensor): n_timesteps interpolated
shape: (n_timesteps + 1,)
mu (torch.Tensor): output of encoder
shape: (batch_size, n_feats, mel_timesteps)
mask (torch.Tensor): output_mask
shape: (batch_size, 1, mel_timesteps)
spks (torch.Tensor, optional): speaker ids. Defaults to None.
shape: (batch_size, spk_emb_dim)
cond: Not used but kept for future purposes
"""
t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
# I am storing this because I can later plot it by putting a debugger here and saving it to a file
# Or in future might add like a return_all_steps flag
sol = []
for step in range(1, len(t_span)):
dphi_dt = self.estimator(x, mask, mu, t, spks, cond)
# Classifier-Free Guidance inference introduced in VoiceBox
if self.inference_cfg_rate > 0:
cfg_dphi_dt = self.estimator(
x, mask,
torch.zeros_like(mu), t,
torch.zeros_like(spks) if spks is not None else None,
torch.zeros_like(cond)
)
dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt -
self.inference_cfg_rate * cfg_dphi_dt)
x = x + dt * dphi_dt
t = t + dt
sol.append(x)
if step < len(t_span) - 1:
dt = t_span[step + 1] - t
return sol[-1]
def compute_loss(self, x1, mask, mu, spks=None, cond=None):
"""Computes diffusion loss
Args:
x1 (torch.Tensor): Target
shape: (batch_size, n_feats, mel_timesteps)
mask (torch.Tensor): target mask
shape: (batch_size, 1, mel_timesteps)
mu (torch.Tensor): output of encoder
shape: (batch_size, n_feats, mel_timesteps)
spks (torch.Tensor, optional): speaker embedding. Defaults to None.
shape: (batch_size, spk_emb_dim)
Returns:
loss: conditional flow matching loss
y: conditional flow
shape: (batch_size, n_feats, mel_timesteps)
"""
b, _, t = mu.shape
# random timestep
t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
if self.t_scheduler == 'cosine':
t = 1 - torch.cos(t * 0.5 * torch.pi)
# sample noise p(x_0)
z = torch.randn_like(x1)
y = (1 - (1 - self.sigma_min) * t) * z + t * x1
u = x1 - (1 - self.sigma_min) * z
# during training, we randomly drop condition to trade off mode coverage and sample fidelity
if self.training_cfg_rate > 0:
cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
mu = mu * cfg_mask.view(-1, 1, 1)
spks = spks * cfg_mask.view(-1, 1)
cond = cond * cfg_mask.view(-1, 1, 1)
pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
return loss, y
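# Toy sketch of the sampling loop above: a fixed-step Euler solve over a cosine-warped time
# grid, with classifier-free guidance mixing the conditional and unconditional velocities
# as (1 + rate) * v_cond - rate * v_uncond. The "estimator" below is a stand-in function,
# not the real decoder.
def _sketch_euler_cfg(n_timesteps=10, inference_cfg_rate=0.7):
    mu = torch.randn(1, 80, 20)
    x = torch.randn_like(mu)
    t_span = torch.linspace(0, 1, n_timesteps + 1)
    t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)          # cosine t_scheduler
    def estimator(x, mu, t):                                 # dummy velocity field
        return mu - x
    t, dt = t_span[0], t_span[1] - t_span[0]
    for step in range(1, len(t_span)):
        v_cond = estimator(x, mu, t)
        v_uncond = estimator(x, torch.zeros_like(mu), t)     # "unconditional" branch
        v = (1.0 + inference_cfg_rate) * v_cond - inference_cfg_rate * v_uncond
        x = x + dt * v
        t = t + dt
        if step < len(t_span) - 1:
            dt = t_span[step + 1] - t
    return x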
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pdb
import torch
import torch.nn.functional as F
from matcha.models.components.flow_matching import BASECFM
class ConditionalCFM(BASECFM):
def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
super().__init__(
n_feats=in_channels,
cfm_params=cfm_params,
n_spks=n_spks,
spk_emb_dim=spk_emb_dim,
)
self.t_scheduler = cfm_params.t_scheduler
self.training_cfg_rate = cfm_params.training_cfg_rate
self.inference_cfg_rate = cfm_params.inference_cfg_rate
in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
        # Just change the architecture of the estimator here
        # NOTE: the values below describe the transformer estimator configuration for
        # reference only; they are not used in this class, since the estimator module
        # is built externally and passed in.
        io_channels = 80
        input_concat_dim = 80
        embed_dim = 768
        depth = 24
        num_heads = 24
        project_cond_tokens = False
        transformer_type = "continuous_transformer"
self.estimator = estimator
@torch.inference_mode()
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
"""Forward diffusion
Args:
mu (torch.Tensor): output of encoder
shape: (batch_size, n_feats, mel_timesteps)
mask (torch.Tensor): output_mask
shape: (batch_size, 1, mel_timesteps)
n_timesteps (int): number of diffusion steps
temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
spks (torch.Tensor, optional): speaker ids. Defaults to None.
shape: (batch_size, spk_emb_dim)
cond: Not used but kept for future purposes
Returns:
sample: generated mel-spectrogram
shape: (batch_size, n_feats, mel_timesteps)
"""
z = torch.randn_like(mu) * temperature
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
if self.t_scheduler == 'cosine':
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)
def solve_euler(self, x, t_span, mu, mask, spks, cond):
"""
Fixed euler solver for ODEs.
Args:
x (torch.Tensor): random noise torch.Size([1, 80, 621])
t_span (torch.Tensor): n_timesteps interpolated
shape: (n_timesteps + 1,)
mu (torch.Tensor): output of encoder
shape: (batch_size, n_feats, mel_timesteps)
mask (torch.Tensor): output_mask
shape: (batch_size, 1, mel_timesteps)
spks (torch.Tensor, optional): speaker ids. Defaults to None.
shape: (batch_size, spk_emb_dim)
cond: Not used but kept for future purposes
"""
t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
# I am storing this because I can later plot it by putting a debugger here and saving it to a file
# Or in future might add like a return_all_steps flag
sol = []
cfg_dropout_prob = 0.1
cfg_scale = 1.0
# cfg_dropout_prob = 0.0
# cfg_scale = 3.0
for step in range(1, len(t_span)):
# dphi_dt = self.estimator(x, mask, mu, t, spks, cond)
# pdb.set_trace()
dphi_dt = self.estimator(x, # [bs, 80, 229]
t[None], # (bs,)
global_embed=spks,
input_concat_cond=mu,
mask=mask[0], # [bs, 229]
cfg_dropout_prob=cfg_dropout_prob, cfg_scale=cfg_scale)
# Classifier-Free Guidance inference introduced in VoiceBox
if self.inference_cfg_rate > 0:
# cfg_dphi_dt = self.estimator(
# x, mask,
# torch.zeros_like(mu), t,
# torch.zeros_like(spks) if spks is not None else None,
# torch.zeros_like(cond)
# )
cfg_dphi_dt = self.estimator(x, # [bs, 80, 229]
t[None], # (bs,)
global_embed=torch.zeros_like(spks) if spks is not None else None,
input_concat_cond=torch.zeros_like(mu),
mask=mask[0], # [bs, 229]
cfg_dropout_prob=cfg_dropout_prob, cfg_scale=cfg_scale)
dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt -
self.inference_cfg_rate * cfg_dphi_dt)
x = x + dt * dphi_dt
t = t + dt
sol.append(x)
if step < len(t_span) - 1:
dt = t_span[step + 1] - t
return sol[-1]
def compute_loss(self, x1, mask, mu, spks=None, cond=None):
"""Computes diffusion loss
Args:
x1 (torch.Tensor): Target
shape: (batch_size, n_feats, mel_timesteps)
mask (torch.Tensor): target mask
shape: (batch_size, 1, mel_timesteps)
mu (torch.Tensor): output of encoder
shape: (batch_size, n_feats, mel_timesteps)
spks (torch.Tensor, optional): speaker embedding. Defaults to None.
shape: (batch_size, spk_emb_dim)
Returns:
loss: conditional flow matching loss
y: conditional flow
shape: (batch_size, n_feats, mel_timesteps)
"""
b, _, t = mu.shape
# random timestep
t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
if self.t_scheduler == 'cosine':
t = 1 - torch.cos(t * 0.5 * torch.pi)
# sample noise p(x_0)
z = torch.randn_like(x1)
y = (1 - (1 - self.sigma_min) * t) * z + t * x1
u = x1 - (1 - self.sigma_min) * z
# during training, we randomly drop condition to trade off mode coverage and sample fidelity
if self.training_cfg_rate > 0:
cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
mu = mu * cfg_mask.view(-1, 1, 1)
spks = spks * cfg_mask.view(-1, 1)
cond = cond * cfg_mask.view(-1, 1, 1)
# pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
pred = self.estimator(y, # [bs, 80, 229]
t.squeeze(1, 2), # (bs,)
global_embed=spks,
input_concat_cond=mu,
mask=mask.squeeze(1), # [bs, 229]
cfg_dropout_prob=0.1)
loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
return loss, y
# def estimator_trans(self):
# pass
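# Toy sketch of the conditional flow-matching target used in compute_loss above: the noisy
# sample y interpolates noise z toward the target x1 along the probability path, and the
# regression target u is the constant velocity x1 - (1 - sigma_min) * z.
def _sketch_cfm_target(sigma_min=1e-6):
    x1 = torch.randn(2, 80, 30)                      # target mel
    t = torch.rand(2, 1, 1)
    t = 1 - torch.cos(t * 0.5 * torch.pi)            # cosine t_scheduler
    z = torch.randn_like(x1)                         # noise sample x_0
    y = (1 - (1 - sigma_min) * t) * z + t * x1       # point on the probability path
    u = x1 - (1 - sigma_min) * z                     # velocity the estimator regresses
    return y, u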
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple
import torch.nn as nn
from torch.nn import functional as F
from cosyvoice.utils.mask import make_pad_mask
class InterpolateRegulator(nn.Module):
def __init__(
self,
channels: int,
sampling_ratios: Tuple,
out_channels: int = None,
groups: int = 1,
):
super().__init__()
self.sampling_ratios = sampling_ratios
out_channels = out_channels or channels
model = nn.ModuleList([])
if len(sampling_ratios) > 0:
for _ in sampling_ratios:
module = nn.Conv1d(channels, channels, 3, 1, 1)
norm = nn.GroupNorm(groups, channels)
act = nn.Mish()
model.extend([module, norm, act])
model.append(
nn.Conv1d(channels, out_channels, 1, 1)
)
self.model = nn.Sequential(*model)
def forward(self, x, ylens=None):
# x in (B, T, D)
mask = (~make_pad_mask(ylens)).to(x).unsqueeze(-1)
x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest')
out = self.model(x).transpose(1, 2).contiguous()
olens = ylens
return out * mask, olens
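# Usage sketch for the regulator above: it nearest-neighbour interpolates token-rate
# features (B, T, D) to the target mel length given by ylens, then smooths them with the
# small conv stack. The lengths below are illustrative.
def _sketch_interpolate_regulator():
    import torch  # torch itself is not imported at module level in this file
    regulator = InterpolateRegulator(channels=80, sampling_ratios=[1])
    x = torch.randn(2, 50, 80)                    # 50 encoder frames per utterance
    ylens = torch.tensor([120, 100])              # target mel lengths per utterance
    out, olens = regulator(x, ylens)              # out: (2, 120, 80), olens == ylens
    return out.shape, olens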
# Copied and modified from https://github.com/archinetai/audio-diffusion-pytorch/blob/v0.0.94/audio_diffusion_pytorch/modules.py under MIT License
# License can be found in LICENSES/LICENSE_ADP.txt
import math
from inspect import isfunction
from math import ceil, floor, log, pi, log2
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
from packaging import version
import torch
import torch.nn as nn
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange
from einops_exts import rearrange_many
from torch import Tensor, einsum
from torch.backends.cuda import sdp_kernel
from torch.nn import functional as F
from dac.nn.layers import Snake1d
"""
Utils
"""
class ConditionedSequential(nn.Module):
def __init__(self, *modules):
super().__init__()
self.module_list = nn.ModuleList(*modules)
def forward(self, x: Tensor, mapping: Optional[Tensor] = None):
for module in self.module_list:
x = module(x, mapping)
return x
T = TypeVar("T")
def default(val: Optional[T], d: Union[Callable[..., T], T]) -> T:
if exists(val):
return val
return d() if isfunction(d) else d
def exists(val: Optional[T]) -> T:
return val is not None
def closest_power_2(x: float) -> int:
exponent = log2(x)
distance_fn = lambda z: abs(x - 2 ** z) # noqa
exponent_closest = min((floor(exponent), ceil(exponent)), key=distance_fn)
return 2 ** int(exponent_closest)
def group_dict_by_prefix(prefix: str, d: Dict) -> Tuple[Dict, Dict]:
return_dicts: Tuple[Dict, Dict] = ({}, {})
for key in d.keys():
no_prefix = int(not key.startswith(prefix))
return_dicts[no_prefix][key] = d[key]
return return_dicts
def groupby(prefix: str, d: Dict, keep_prefix: bool = False) -> Tuple[Dict, Dict]:
kwargs_with_prefix, kwargs = group_dict_by_prefix(prefix, d)
if keep_prefix:
return kwargs_with_prefix, kwargs
kwargs_no_prefix = {k[len(prefix) :]: v for k, v in kwargs_with_prefix.items()}
return kwargs_no_prefix, kwargs
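# Small illustrative example of the prefix-grouping helper above: keyword
# arguments are split into those carrying the prefix (stripped by default)
# and the rest.
def _demo_groupby():
    kwargs = {"attention_heads": 8, "attention_features": 64, "channels": 128}
    attn_kwargs, rest = groupby("attention_", kwargs)
    assert attn_kwargs == {"heads": 8, "features": 64}
    assert rest == {"channels": 128}
    return attn_kwargs, rest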
"""
Convolutional Blocks
"""
import typing as tp
# Copied from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/modules/conv.py under MIT License
# License available in LICENSES/LICENSE_META.txt
def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
padding_total: int = 0) -> int:
"""See `pad_for_conv1d`."""
length = x.shape[-1]
n_frames = (length - kernel_size + padding_total) / stride + 1
ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
return ideal_length - length
def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0):
"""Pad for a convolution to make sure that the last window is full.
Extra padding is added at the end. This is required to ensure that we can rebuild
an output of the same length, as otherwise, even with padding, some time steps
might get removed.
For instance, with total padding = 4, kernel size = 4, stride = 2:
0 0 1 2 3 4 5 0 0 # (0s are padding)
1 2 3 # (output frames of a convolution, last 0 is never used)
0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 5 is going to get removed as padding)
1 2 3 4 # once you removed padding, we are missing one time step !
"""
extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
return F.pad(x, (0, extra_padding))
def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', value: float = 0.):
"""Tiny wrapper around F.pad, just to allow for reflect padding on small input.
If this is the case, we insert extra 0 padding to the right before the reflection happen.
"""
length = x.shape[-1]
padding_left, padding_right = paddings
assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
if mode == 'reflect':
max_pad = max(padding_left, padding_right)
extra_pad = 0
if length <= max_pad:
extra_pad = max_pad - length + 1
x = F.pad(x, (0, extra_pad))
padded = F.pad(x, paddings, mode, value)
end = padded.shape[-1] - extra_pad
return padded[..., :end]
else:
return F.pad(x, paddings, mode, value)
def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
"""Remove padding from x, handling properly zero padding. Only for 1d!"""
padding_left, padding_right = paddings
assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
assert (padding_left + padding_right) <= x.shape[-1]
end = x.shape[-1] - padding_right
return x[..., padding_left: end]
class Conv1d(nn.Conv1d):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def forward(self, x: Tensor, causal=False) -> Tensor:
kernel_size = self.kernel_size[0]
stride = self.stride[0]
dilation = self.dilation[0]
kernel_size = (kernel_size - 1) * dilation + 1 # effective kernel size with dilations
padding_total = kernel_size - stride
extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
if causal:
# Left padding for causal
x = pad1d(x, (padding_total, extra_padding))
else:
# Asymmetric padding required for odd strides
padding_right = padding_total // 2
padding_left = padding_total - padding_right
x = pad1d(x, (padding_left, padding_right + extra_padding))
return super().forward(x)
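# Shape sketch (toy sizes): the Conv1d wrapper pads so the output length is
# ceil(length / stride), either causally (left padding only) or symmetrically.
def _demo_padded_conv1d():
    conv = Conv1d(in_channels=4, out_channels=4, kernel_size=5, stride=2)
    x = torch.randn(1, 4, 63)
    y_same = conv(x, causal=False)              # (1, 4, 32)
    y_causal = conv(x, causal=True)             # (1, 4, 32), no lookahead
    return y_same.shape, y_causal.shape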
class ConvTranspose1d(nn.ConvTranspose1d):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def forward(self, x: Tensor, causal=False) -> Tensor:
kernel_size = self.kernel_size[0]
stride = self.stride[0]
padding_total = kernel_size - stride
y = super().forward(x)
# We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
# removed at the very end, when keeping only the right length for the output,
# as removing it here would require also passing the length at the matching layer
# in the encoder.
if causal:
padding_right = ceil(padding_total)
padding_left = padding_total - padding_right
y = unpad1d(y, (padding_left, padding_right))
else:
# Asymmetric padding required for odd strides
padding_right = padding_total // 2
padding_left = padding_total - padding_right
y = unpad1d(y, (padding_left, padding_right))
return y
def Downsample1d(
in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2
) -> nn.Module:
assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even"
return Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=factor * kernel_multiplier + 1,
stride=factor
)
def Upsample1d(
in_channels: int, out_channels: int, factor: int, use_nearest: bool = False
) -> nn.Module:
if factor == 1:
return Conv1d(
in_channels=in_channels, out_channels=out_channels, kernel_size=3
)
if use_nearest:
return nn.Sequential(
nn.Upsample(scale_factor=factor, mode="nearest"),
Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3
),
)
else:
return ConvTranspose1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=factor * 2,
stride=factor
)
class ConvBlock1d(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
*,
kernel_size: int = 3,
stride: int = 1,
dilation: int = 1,
num_groups: int = 8,
use_norm: bool = True,
use_snake: bool = False
) -> None:
super().__init__()
self.groupnorm = (
nn.GroupNorm(num_groups=num_groups, num_channels=in_channels)
if use_norm
else nn.Identity()
)
if use_snake:
self.activation = Snake1d(in_channels)
else:
self.activation = nn.SiLU()
self.project = Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
dilation=dilation,
)
def forward(
self, x: Tensor, scale_shift: Optional[Tuple[Tensor, Tensor]] = None, causal=False
) -> Tensor:
x = self.groupnorm(x)
if exists(scale_shift):
scale, shift = scale_shift
x = x * (scale + 1) + shift
x = self.activation(x)
return self.project(x, causal=causal)
class MappingToScaleShift(nn.Module):
def __init__(
self,
features: int,
channels: int,
):
super().__init__()
self.to_scale_shift = nn.Sequential(
nn.SiLU(),
nn.Linear(in_features=features, out_features=channels * 2),
)
def forward(self, mapping: Tensor) -> Tuple[Tensor, Tensor]:
scale_shift = self.to_scale_shift(mapping)
scale_shift = rearrange(scale_shift, "b c -> b c 1")
scale, shift = scale_shift.chunk(2, dim=1)
return scale, shift
class ResnetBlock1d(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
*,
kernel_size: int = 3,
stride: int = 1,
dilation: int = 1,
use_norm: bool = True,
use_snake: bool = False,
num_groups: int = 8,
context_mapping_features: Optional[int] = None,
) -> None:
super().__init__()
self.use_mapping = exists(context_mapping_features)
self.block1 = ConvBlock1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
dilation=dilation,
use_norm=use_norm,
num_groups=num_groups,
use_snake=use_snake
)
if self.use_mapping:
assert exists(context_mapping_features)
self.to_scale_shift = MappingToScaleShift(
features=context_mapping_features, channels=out_channels
)
self.block2 = ConvBlock1d(
in_channels=out_channels,
out_channels=out_channels,
use_norm=use_norm,
num_groups=num_groups,
use_snake=use_snake
)
self.to_out = (
Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1)
if in_channels != out_channels
else nn.Identity()
)
def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor:
assert_message = "context mapping required if context_mapping_features > 0"
assert not (self.use_mapping ^ exists(mapping)), assert_message
h = self.block1(x, causal=causal)
scale_shift = None
if self.use_mapping:
scale_shift = self.to_scale_shift(mapping)
h = self.block2(h, scale_shift=scale_shift, causal=causal)
return h + self.to_out(x)
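# Illustrative usage (assumed sizes): a residual block conditioned on a global
# mapping vector through the FiLM-style scale/shift above.
def _demo_resnet_block_mapping():
    block = ResnetBlock1d(in_channels=32, out_channels=64, num_groups=8,
                          context_mapping_features=128)
    x = torch.randn(2, 32, 100)
    mapping = torch.randn(2, 128)
    return block(x, mapping=mapping).shape      # (2, 64, 100)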
class Patcher(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
patch_size: int,
context_mapping_features: Optional[int] = None,
use_snake: bool = False,
):
super().__init__()
assert_message = f"out_channels must be divisible by patch_size ({patch_size})"
assert out_channels % patch_size == 0, assert_message
self.patch_size = patch_size
self.block = ResnetBlock1d(
in_channels=in_channels,
out_channels=out_channels // patch_size,
num_groups=1,
context_mapping_features=context_mapping_features,
use_snake=use_snake
)
def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor:
x = self.block(x, mapping, causal=causal)
x = rearrange(x, "b c (l p) -> b (c p) l", p=self.patch_size)
return x
class Unpatcher(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
patch_size: int,
context_mapping_features: Optional[int] = None,
use_snake: bool = False
):
super().__init__()
assert_message = f"in_channels must be divisible by patch_size ({patch_size})"
assert in_channels % patch_size == 0, assert_message
self.patch_size = patch_size
self.block = ResnetBlock1d(
in_channels=in_channels // patch_size,
out_channels=out_channels,
num_groups=1,
context_mapping_features=context_mapping_features,
use_snake=use_snake
)
def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor:
x = rearrange(x, " b (c p) l -> b c (l p) ", p=self.patch_size)
x = self.block(x, mapping, causal=causal)
return x
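# Shape sketch (illustrative): Patcher trades time resolution for channels and
# Unpatcher inverts it, so the pair round-trips (B, C, T) shapes.
def _demo_patching():
    to_patched = Patcher(in_channels=8, out_channels=16, patch_size=2)
    to_unpatched = Unpatcher(in_channels=16, out_channels=8, patch_size=2)
    x = torch.randn(2, 8, 64)
    h = to_patched(x)                           # (2, 16, 32)
    y = to_unpatched(h)                         # (2, 8, 64)
    return h.shape, y.shape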
"""
Attention Components
"""
def FeedForward(features: int, multiplier: int) -> nn.Module:
mid_features = features * multiplier
return nn.Sequential(
nn.Linear(in_features=features, out_features=mid_features),
nn.GELU(),
nn.Linear(in_features=mid_features, out_features=features),
)
def add_mask(sim: Tensor, mask: Tensor) -> Tensor:
b, ndim = sim.shape[0], mask.ndim
if ndim == 3:
mask = rearrange(mask, "b n m -> b 1 n m")
if ndim == 2:
mask = repeat(mask, "n m -> b 1 n m", b=b)
max_neg_value = -torch.finfo(sim.dtype).max
sim = sim.masked_fill(~mask, max_neg_value)
return sim
def causal_mask(q: Tensor, k: Tensor) -> Tensor:
b, i, j, device = q.shape[0], q.shape[-2], k.shape[-2], q.device
mask = ~torch.ones((i, j), dtype=torch.bool, device=device).triu(j - i + 1)
mask = repeat(mask, "n m -> b n m", b=b)
return mask
class AttentionBase(nn.Module):
def __init__(
self,
features: int,
*,
head_features: int,
num_heads: int,
out_features: Optional[int] = None,
):
super().__init__()
self.scale = head_features**-0.5
self.num_heads = num_heads
mid_features = head_features * num_heads
out_features = default(out_features, features)
self.to_out = nn.Linear(
in_features=mid_features, out_features=out_features
)
self.use_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0')
if not self.use_flash:
return
device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
if device_properties.major == 8 and device_properties.minor == 0:
# Use flash attention for A100 GPUs
self.sdp_kernel_config = (True, False, False)
else:
# Don't use flash attention for other GPUs
self.sdp_kernel_config = (False, True, True)
def forward(
self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None, is_causal: bool = False
) -> Tensor:
# Split heads
q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=self.num_heads)
if not self.use_flash:
            if is_causal and mask is None:
# Mask out future tokens for causal attention
mask = causal_mask(q, k)
# Compute similarity matrix and add eventual mask
sim = einsum("... n d, ... m d -> ... n m", q, k) * self.scale
sim = add_mask(sim, mask) if exists(mask) else sim
# Get attention matrix with softmax
attn = sim.softmax(dim=-1, dtype=torch.float32)
# Compute values
out = einsum("... n m, ... m d -> ... n d", attn, v)
else:
with sdp_kernel(*self.sdp_kernel_config):
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, is_causal=is_causal)
out = rearrange(out, "b h n d -> b n (h d)")
return self.to_out(out)
class Attention(nn.Module):
def __init__(
self,
features: int,
*,
head_features: int,
num_heads: int,
out_features: Optional[int] = None,
context_features: Optional[int] = None,
causal: bool = False,
):
super().__init__()
self.context_features = context_features
self.causal = causal
mid_features = head_features * num_heads
context_features = default(context_features, features)
self.norm = nn.LayerNorm(features)
self.norm_context = nn.LayerNorm(context_features)
self.to_q = nn.Linear(
in_features=features, out_features=mid_features, bias=False
)
self.to_kv = nn.Linear(
in_features=context_features, out_features=mid_features * 2, bias=False
)
self.attention = AttentionBase(
features,
num_heads=num_heads,
head_features=head_features,
out_features=out_features,
)
def forward(
self,
x: Tensor, # [b, n, c]
context: Optional[Tensor] = None, # [b, m, d]
context_mask: Optional[Tensor] = None, # [b, m], false is masked,
causal: Optional[bool] = False,
) -> Tensor:
assert_message = "You must provide a context when using context_features"
assert not self.context_features or exists(context), assert_message
# Use context if provided
context = default(context, x)
# Normalize then compute q from input and k,v from context
x, context = self.norm(x), self.norm_context(context)
q, k, v = (self.to_q(x), *torch.chunk(self.to_kv(context), chunks=2, dim=-1))
if exists(context_mask):
# Mask out cross-attention for padding tokens
mask = repeat(context_mask, "b m -> b m d", d=v.shape[-1])
k, v = k * mask, v * mask
# Compute and return attention
return self.attention(q, k, v, is_causal=self.causal or causal)
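# Usage sketch (toy sizes): cross-attention from a 64-dim sequence onto a
# 32-dim context, with padded context positions zeroed via context_mask.
def _demo_cross_attention():
    attn = Attention(features=64, head_features=16, num_heads=4,
                     context_features=32)
    x = torch.randn(2, 10, 64)
    context = torch.randn(2, 7, 32)
    context_mask = torch.ones(2, 7, dtype=torch.bool)
    return attn(x, context=context, context_mask=context_mask).shape  # (2, 10, 64)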
"""
Transformer Blocks
"""
class TransformerBlock(nn.Module):
def __init__(
self,
features: int,
num_heads: int,
head_features: int,
multiplier: int,
context_features: Optional[int] = None,
):
super().__init__()
self.use_cross_attention = exists(context_features) and context_features > 0
self.attention = Attention(
features=features,
num_heads=num_heads,
head_features=head_features
)
if self.use_cross_attention:
self.cross_attention = Attention(
features=features,
num_heads=num_heads,
head_features=head_features,
context_features=context_features
)
self.feed_forward = FeedForward(features=features, multiplier=multiplier)
def forward(self, x: Tensor, *, context: Optional[Tensor] = None, context_mask: Optional[Tensor] = None, causal: Optional[bool] = False) -> Tensor:
x = self.attention(x, causal=causal) + x
if self.use_cross_attention:
x = self.cross_attention(x, context=context, context_mask=context_mask) + x
x = self.feed_forward(x) + x
return x
"""
Transformers
"""
class Transformer1d(nn.Module):
def __init__(
self,
num_layers: int,
channels: int,
num_heads: int,
head_features: int,
multiplier: int,
context_features: Optional[int] = None,
):
super().__init__()
self.to_in = nn.Sequential(
nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6, affine=True),
Conv1d(
in_channels=channels,
out_channels=channels,
kernel_size=1,
),
Rearrange("b c t -> b t c"),
)
self.blocks = nn.ModuleList(
[
TransformerBlock(
features=channels,
head_features=head_features,
num_heads=num_heads,
multiplier=multiplier,
context_features=context_features,
)
for i in range(num_layers)
]
)
self.to_out = nn.Sequential(
Rearrange("b t c -> b c t"),
Conv1d(
in_channels=channels,
out_channels=channels,
kernel_size=1,
),
)
def forward(self, x: Tensor, *, context: Optional[Tensor] = None, context_mask: Optional[Tensor] = None, causal=False) -> Tensor:
x = self.to_in(x)
for block in self.blocks:
x = block(x, context=context, context_mask=context_mask, causal=causal)
x = self.to_out(x)
return x
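# Usage sketch (assumed sizes): a channel-last transformer over a (B, C, T)
# feature map with optional cross-attention context.
def _demo_transformer1d():
    model = Transformer1d(num_layers=2, channels=64, num_heads=4,
                          head_features=16, multiplier=2, context_features=32)
    x = torch.randn(2, 64, 50)
    context = torch.randn(2, 7, 32)
    return model(x, context=context).shape      # (2, 64, 50)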
"""
Time Embeddings
"""
class SinusoidalEmbedding(nn.Module):
def __init__(self, dim: int):
super().__init__()
self.dim = dim
def forward(self, x: Tensor) -> Tensor:
device, half_dim = x.device, self.dim // 2
emb = torch.tensor(log(10000) / (half_dim - 1), device=device)
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
emb = rearrange(x, "i -> i 1") * rearrange(emb, "j -> 1 j")
return torch.cat((emb.sin(), emb.cos()), dim=-1)
class LearnedPositionalEmbedding(nn.Module):
"""Used for continuous time"""
def __init__(self, dim: int):
super().__init__()
assert (dim % 2) == 0
half_dim = dim // 2
self.weights = nn.Parameter(torch.randn(half_dim))
def forward(self, x: Tensor) -> Tensor:
x = rearrange(x, "b -> b 1")
freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * pi
fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
fouriered = torch.cat((x, fouriered), dim=-1)
return fouriered
def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module:
return nn.Sequential(
LearnedPositionalEmbedding(dim),
nn.Linear(in_features=dim + 1, out_features=out_features),
)
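# Shape sketch (illustrative sizes): a batch of scalar diffusion times mapped
# to a conditioning vector.
def _demo_time_embedding():
    to_time = TimePositionalEmbedding(dim=32, out_features=128)
    t = torch.rand(4)                           # one time value per batch item
    return to_time(t).shape                     # (4, 128)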
"""
Encoder/Decoder Components
"""
class DownsampleBlock1d(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
*,
factor: int,
num_groups: int,
num_layers: int,
kernel_multiplier: int = 2,
use_pre_downsample: bool = True,
use_skip: bool = False,
use_snake: bool = False,
extract_channels: int = 0,
context_channels: int = 0,
num_transformer_blocks: int = 0,
attention_heads: Optional[int] = None,
attention_features: Optional[int] = None,
attention_multiplier: Optional[int] = None,
context_mapping_features: Optional[int] = None,
context_embedding_features: Optional[int] = None,
):
super().__init__()
self.use_pre_downsample = use_pre_downsample
self.use_skip = use_skip
self.use_transformer = num_transformer_blocks > 0
self.use_extract = extract_channels > 0
self.use_context = context_channels > 0
channels = out_channels if use_pre_downsample else in_channels
self.downsample = Downsample1d(
in_channels=in_channels,
out_channels=out_channels,
factor=factor,
kernel_multiplier=kernel_multiplier,
)
self.blocks = nn.ModuleList(
[
ResnetBlock1d(
in_channels=channels + context_channels if i == 0 else channels,
out_channels=channels,
num_groups=num_groups,
context_mapping_features=context_mapping_features,
use_snake=use_snake
)
for i in range(num_layers)
]
)
if self.use_transformer:
assert (
(exists(attention_heads) or exists(attention_features))
and exists(attention_multiplier)
)
if attention_features is None and attention_heads is not None:
attention_features = channels // attention_heads
if attention_heads is None and attention_features is not None:
attention_heads = channels // attention_features
self.transformer = Transformer1d(
num_layers=num_transformer_blocks,
channels=channels,
num_heads=attention_heads,
head_features=attention_features,
multiplier=attention_multiplier,
context_features=context_embedding_features
)
if self.use_extract:
num_extract_groups = min(num_groups, extract_channels)
self.to_extracted = ResnetBlock1d(
in_channels=out_channels,
out_channels=extract_channels,
num_groups=num_extract_groups,
use_snake=use_snake
)
def forward(
self,
x: Tensor,
*,
mapping: Optional[Tensor] = None,
channels: Optional[Tensor] = None,
embedding: Optional[Tensor] = None,
embedding_mask: Optional[Tensor] = None,
causal: Optional[bool] = False
) -> Union[Tuple[Tensor, List[Tensor]], Tensor]:
if self.use_pre_downsample:
x = self.downsample(x)
if self.use_context and exists(channels):
x = torch.cat([x, channels], dim=1)
skips = []
for block in self.blocks:
x = block(x, mapping=mapping, causal=causal)
skips += [x] if self.use_skip else []
if self.use_transformer:
x = self.transformer(x, context=embedding, context_mask=embedding_mask, causal=causal)
skips += [x] if self.use_skip else []
if not self.use_pre_downsample:
x = self.downsample(x)
if self.use_extract:
extracted = self.to_extracted(x)
return x, extracted
return (x, skips) if self.use_skip else x
class UpsampleBlock1d(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
*,
factor: int,
num_layers: int,
num_groups: int,
use_nearest: bool = False,
use_pre_upsample: bool = False,
use_skip: bool = False,
use_snake: bool = False,
skip_channels: int = 0,
use_skip_scale: bool = False,
extract_channels: int = 0,
num_transformer_blocks: int = 0,
attention_heads: Optional[int] = None,
attention_features: Optional[int] = None,
attention_multiplier: Optional[int] = None,
context_mapping_features: Optional[int] = None,
context_embedding_features: Optional[int] = None,
):
super().__init__()
self.use_extract = extract_channels > 0
self.use_pre_upsample = use_pre_upsample
self.use_transformer = num_transformer_blocks > 0
self.use_skip = use_skip
self.skip_scale = 2 ** -0.5 if use_skip_scale else 1.0
channels = out_channels if use_pre_upsample else in_channels
self.blocks = nn.ModuleList(
[
ResnetBlock1d(
in_channels=channels + skip_channels,
out_channels=channels,
num_groups=num_groups,
context_mapping_features=context_mapping_features,
use_snake=use_snake
)
for _ in range(num_layers)
]
)
if self.use_transformer:
assert (
(exists(attention_heads) or exists(attention_features))
and exists(attention_multiplier)
)
if attention_features is None and attention_heads is not None:
attention_features = channels // attention_heads
if attention_heads is None and attention_features is not None:
attention_heads = channels // attention_features
self.transformer = Transformer1d(
num_layers=num_transformer_blocks,
channels=channels,
num_heads=attention_heads,
head_features=attention_features,
multiplier=attention_multiplier,
context_features=context_embedding_features,
)
self.upsample = Upsample1d(
in_channels=in_channels,
out_channels=out_channels,
factor=factor,
use_nearest=use_nearest,
)
if self.use_extract:
num_extract_groups = min(num_groups, extract_channels)
self.to_extracted = ResnetBlock1d(
in_channels=out_channels,
out_channels=extract_channels,
num_groups=num_extract_groups,
use_snake=use_snake
)
def add_skip(self, x: Tensor, skip: Tensor) -> Tensor:
return torch.cat([x, skip * self.skip_scale], dim=1)
def forward(
self,
x: Tensor,
*,
skips: Optional[List[Tensor]] = None,
mapping: Optional[Tensor] = None,
embedding: Optional[Tensor] = None,
embedding_mask: Optional[Tensor] = None,
causal: Optional[bool] = False
) -> Union[Tuple[Tensor, Tensor], Tensor]:
if self.use_pre_upsample:
x = self.upsample(x)
for block in self.blocks:
x = self.add_skip(x, skip=skips.pop()) if exists(skips) else x
x = block(x, mapping=mapping, causal=causal)
if self.use_transformer:
x = self.transformer(x, context=embedding, context_mask=embedding_mask, causal=causal)
if not self.use_pre_upsample:
x = self.upsample(x)
if self.use_extract:
extracted = self.to_extracted(x)
return x, extracted
return x
class BottleneckBlock1d(nn.Module):
def __init__(
self,
channels: int,
*,
num_groups: int,
num_transformer_blocks: int = 0,
attention_heads: Optional[int] = None,
attention_features: Optional[int] = None,
attention_multiplier: Optional[int] = None,
context_mapping_features: Optional[int] = None,
context_embedding_features: Optional[int] = None,
use_snake: bool = False,
):
super().__init__()
self.use_transformer = num_transformer_blocks > 0
self.pre_block = ResnetBlock1d(
in_channels=channels,
out_channels=channels,
num_groups=num_groups,
context_mapping_features=context_mapping_features,
use_snake=use_snake
)
if self.use_transformer:
assert (
(exists(attention_heads) or exists(attention_features))
and exists(attention_multiplier)
)
if attention_features is None and attention_heads is not None:
attention_features = channels // attention_heads
if attention_heads is None and attention_features is not None:
attention_heads = channels // attention_features
self.transformer = Transformer1d(
num_layers=num_transformer_blocks,
channels=channels,
num_heads=attention_heads,
head_features=attention_features,
multiplier=attention_multiplier,
context_features=context_embedding_features,
)
self.post_block = ResnetBlock1d(
in_channels=channels,
out_channels=channels,
num_groups=num_groups,
context_mapping_features=context_mapping_features,
use_snake=use_snake
)
def forward(
self,
x: Tensor,
*,
mapping: Optional[Tensor] = None,
embedding: Optional[Tensor] = None,
embedding_mask: Optional[Tensor] = None,
causal: Optional[bool] = False
) -> Tensor:
x = self.pre_block(x, mapping=mapping, causal=causal)
if self.use_transformer:
x = self.transformer(x, context=embedding, context_mask=embedding_mask, causal=causal)
x = self.post_block(x, mapping=mapping, causal=causal)
return x
"""
UNet
"""
class UNet1d(nn.Module):
def __init__(
self,
in_channels: int,
channels: int,
multipliers: Sequence[int],
factors: Sequence[int],
num_blocks: Sequence[int],
attentions: Sequence[int],
patch_size: int = 1,
resnet_groups: int = 8,
use_context_time: bool = True,
kernel_multiplier_downsample: int = 2,
use_nearest_upsample: bool = False,
use_skip_scale: bool = True,
use_snake: bool = False,
use_stft: bool = False,
use_stft_context: bool = False,
out_channels: Optional[int] = None,
context_features: Optional[int] = None,
context_features_multiplier: int = 4,
context_channels: Optional[Sequence[int]] = None,
context_embedding_features: Optional[int] = None,
**kwargs,
):
super().__init__()
out_channels = default(out_channels, in_channels)
context_channels = list(default(context_channels, []))
num_layers = len(multipliers) - 1
use_context_features = exists(context_features)
use_context_channels = len(context_channels) > 0
context_mapping_features = None
attention_kwargs, kwargs = groupby("attention_", kwargs, keep_prefix=True)
self.num_layers = num_layers
self.use_context_time = use_context_time
self.use_context_features = use_context_features
self.use_context_channels = use_context_channels
self.use_stft = use_stft
self.use_stft_context = use_stft_context
self.context_features = context_features
context_channels_pad_length = num_layers + 1 - len(context_channels)
context_channels = context_channels + [0] * context_channels_pad_length
self.context_channels = context_channels
self.context_embedding_features = context_embedding_features
if use_context_channels:
has_context = [c > 0 for c in context_channels]
self.has_context = has_context
self.channels_ids = [sum(has_context[:i]) for i in range(len(has_context))]
assert (
len(factors) == num_layers
and len(attentions) >= num_layers
and len(num_blocks) == num_layers
)
if use_context_time or use_context_features:
context_mapping_features = channels * context_features_multiplier
self.to_mapping = nn.Sequential(
nn.Linear(context_mapping_features, context_mapping_features),
nn.GELU(),
nn.Linear(context_mapping_features, context_mapping_features),
nn.GELU(),
)
if use_context_time:
assert exists(context_mapping_features)
self.to_time = nn.Sequential(
TimePositionalEmbedding(
dim=channels, out_features=context_mapping_features
),
nn.GELU(),
)
if use_context_features:
assert exists(context_features) and exists(context_mapping_features)
self.to_features = nn.Sequential(
nn.Linear(
in_features=context_features, out_features=context_mapping_features
),
nn.GELU(),
)
if use_stft:
stft_kwargs, kwargs = groupby("stft_", kwargs)
assert "num_fft" in stft_kwargs, "stft_num_fft required if use_stft=True"
stft_channels = (stft_kwargs["num_fft"] // 2 + 1) * 2
in_channels *= stft_channels
out_channels *= stft_channels
context_channels[0] *= stft_channels if use_stft_context else 1
assert exists(in_channels) and exists(out_channels)
self.stft = STFT(**stft_kwargs)
assert not kwargs, f"Unknown arguments: {', '.join(list(kwargs.keys()))}"
self.to_in = Patcher(
in_channels=in_channels + context_channels[0],
out_channels=channels * multipliers[0],
patch_size=patch_size,
context_mapping_features=context_mapping_features,
use_snake=use_snake
)
self.downsamples = nn.ModuleList(
[
DownsampleBlock1d(
in_channels=channels * multipliers[i],
out_channels=channels * multipliers[i + 1],
context_mapping_features=context_mapping_features,
context_channels=context_channels[i + 1],
context_embedding_features=context_embedding_features,
num_layers=num_blocks[i],
factor=factors[i],
kernel_multiplier=kernel_multiplier_downsample,
num_groups=resnet_groups,
use_pre_downsample=True,
use_skip=True,
use_snake=use_snake,
num_transformer_blocks=attentions[i],
**attention_kwargs,
)
for i in range(num_layers)
]
)
self.bottleneck = BottleneckBlock1d(
channels=channels * multipliers[-1],
context_mapping_features=context_mapping_features,
context_embedding_features=context_embedding_features,
num_groups=resnet_groups,
num_transformer_blocks=attentions[-1],
use_snake=use_snake,
**attention_kwargs,
)
self.upsamples = nn.ModuleList(
[
UpsampleBlock1d(
in_channels=channels * multipliers[i + 1],
out_channels=channels * multipliers[i],
context_mapping_features=context_mapping_features,
context_embedding_features=context_embedding_features,
num_layers=num_blocks[i] + (1 if attentions[i] else 0),
factor=factors[i],
use_nearest=use_nearest_upsample,
num_groups=resnet_groups,
use_skip_scale=use_skip_scale,
use_pre_upsample=False,
use_skip=True,
use_snake=use_snake,
skip_channels=channels * multipliers[i + 1],
num_transformer_blocks=attentions[i],
**attention_kwargs,
)
for i in reversed(range(num_layers))
]
)
self.to_out = Unpatcher(
in_channels=channels * multipliers[0],
out_channels=out_channels,
patch_size=patch_size,
context_mapping_features=context_mapping_features,
use_snake=use_snake
)
def get_channels(
self, channels_list: Optional[Sequence[Tensor]] = None, layer: int = 0
) -> Optional[Tensor]:
"""Gets context channels at `layer` and checks that shape is correct"""
use_context_channels = self.use_context_channels and self.has_context[layer]
if not use_context_channels:
return None
assert exists(channels_list), "Missing context"
# Get channels index (skipping zero channel contexts)
channels_id = self.channels_ids[layer]
# Get channels
channels = channels_list[channels_id]
message = f"Missing context for layer {layer} at index {channels_id}"
assert exists(channels), message
# Check channels
num_channels = self.context_channels[layer]
message = f"Expected context with {num_channels} channels at idx {channels_id}"
assert channels.shape[1] == num_channels, message
# STFT channels if requested
channels = self.stft.encode1d(channels) if self.use_stft_context else channels # type: ignore # noqa
return channels
def get_mapping(
self, time: Optional[Tensor] = None, features: Optional[Tensor] = None
) -> Optional[Tensor]:
"""Combines context time features and features into mapping"""
items, mapping = [], None
# Compute time features
if self.use_context_time:
assert_message = "use_context_time=True but no time features provided"
assert exists(time), assert_message
items += [self.to_time(time)]
# Compute features
if self.use_context_features:
assert_message = "context_features exists but no features provided"
assert exists(features), assert_message
items += [self.to_features(features)]
# Compute joint mapping
if self.use_context_time or self.use_context_features:
mapping = reduce(torch.stack(items), "n b m -> b m", "sum")
mapping = self.to_mapping(mapping)
return mapping
def forward(
self,
x: Tensor,
time: Optional[Tensor] = None,
*,
features: Optional[Tensor] = None,
channels_list: Optional[Sequence[Tensor]] = None,
embedding: Optional[Tensor] = None,
embedding_mask: Optional[Tensor] = None,
causal: Optional[bool] = False,
) -> Tensor:
channels = self.get_channels(channels_list, layer=0)
# Apply stft if required
x = self.stft.encode1d(x) if self.use_stft else x # type: ignore
# Concat context channels at layer 0 if provided
x = torch.cat([x, channels], dim=1) if exists(channels) else x
# Compute mapping from time and features
mapping = self.get_mapping(time, features)
x = self.to_in(x, mapping, causal=causal)
skips_list = [x]
for i, downsample in enumerate(self.downsamples):
channels = self.get_channels(channels_list, layer=i + 1)
x, skips = downsample(
x, mapping=mapping, channels=channels, embedding=embedding, embedding_mask=embedding_mask, causal=causal
)
skips_list += [skips]
x = self.bottleneck(x, mapping=mapping, embedding=embedding, embedding_mask=embedding_mask, causal=causal)
for i, upsample in enumerate(self.upsamples):
skips = skips_list.pop()
x = upsample(x, skips=skips, mapping=mapping, embedding=embedding, embedding_mask=embedding_mask, causal=causal)
x += skips_list.pop()
x = self.to_out(x, mapping, causal=causal)
x = self.stft.decode1d(x) if self.use_stft else x
return x
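# End-to-end shape sketch for UNet1d; every hyperparameter below is a toy
# assumption, not the configuration used elsewhere in this repo.
def _demo_unet1d():
    unet = UNet1d(
        in_channels=8, channels=32,
        multipliers=[1, 2], factors=[2], num_blocks=[1], attentions=[0, 0],
        context_features=16,                    # optional global feature vector
    )
    x = torch.randn(2, 8, 64)
    time = torch.rand(2)
    features = torch.randn(2, 16)
    return unet(x, time, features=features).shape  # (2, 8, 64)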
""" Conditioning Modules """
class FixedEmbedding(nn.Module):
def __init__(self, max_length: int, features: int):
super().__init__()
self.max_length = max_length
self.embedding = nn.Embedding(max_length, features)
def forward(self, x: Tensor) -> Tensor:
batch_size, length, device = *x.shape[0:2], x.device
assert_message = "Input sequence length must be <= max_length"
assert length <= self.max_length, assert_message
position = torch.arange(length, device=device)
fixed_embedding = self.embedding(position)
fixed_embedding = repeat(fixed_embedding, "n d -> b n d", b=batch_size)
return fixed_embedding
def rand_bool(shape: Any, proba: float, device: Any = None) -> Tensor:
if proba == 1:
return torch.ones(shape, device=device, dtype=torch.bool)
elif proba == 0:
return torch.zeros(shape, device=device, dtype=torch.bool)
else:
return torch.bernoulli(torch.full(shape, proba, device=device)).to(torch.bool)
class UNetCFG1d(UNet1d):
"""UNet1d with Classifier-Free Guidance"""
def __init__(
self,
context_embedding_max_length: int,
context_embedding_features: int,
use_xattn_time: bool = False,
**kwargs,
):
super().__init__(
context_embedding_features=context_embedding_features, **kwargs
)
self.use_xattn_time = use_xattn_time
if use_xattn_time:
assert exists(context_embedding_features)
self.to_time_embedding = nn.Sequential(
TimePositionalEmbedding(
dim=kwargs["channels"], out_features=context_embedding_features
),
nn.GELU(),
)
context_embedding_max_length += 1 # Add one for time embedding
self.fixed_embedding = FixedEmbedding(
max_length=context_embedding_max_length, features=context_embedding_features
)
def forward( # type: ignore
self,
x: Tensor,
time: Tensor,
*,
embedding: Tensor,
embedding_mask: Optional[Tensor] = None,
embedding_scale: float = 1.0,
embedding_mask_proba: float = 0.0,
batch_cfg: bool = False,
rescale_cfg: bool = False,
scale_phi: float = 0.4,
negative_embedding: Optional[Tensor] = None,
negative_embedding_mask: Optional[Tensor] = None,
**kwargs,
) -> Tensor:
b, device = embedding.shape[0], embedding.device
if self.use_xattn_time:
embedding = torch.cat([embedding, self.to_time_embedding(time).unsqueeze(1)], dim=1)
if embedding_mask is not None:
embedding_mask = torch.cat([embedding_mask, torch.ones((b, 1), device=device)], dim=1)
fixed_embedding = self.fixed_embedding(embedding)
if embedding_mask_proba > 0.0:
# Randomly mask embedding
batch_mask = rand_bool(
shape=(b, 1, 1), proba=embedding_mask_proba, device=device
)
embedding = torch.where(batch_mask, fixed_embedding, embedding)
if embedding_scale != 1.0:
if batch_cfg:
batch_x = torch.cat([x, x], dim=0)
batch_time = torch.cat([time, time], dim=0)
if negative_embedding is not None:
if negative_embedding_mask is not None:
negative_embedding_mask = negative_embedding_mask.to(torch.bool).unsqueeze(2)
negative_embedding = torch.where(negative_embedding_mask, negative_embedding, fixed_embedding)
batch_embed = torch.cat([embedding, negative_embedding], dim=0)
else:
batch_embed = torch.cat([embedding, fixed_embedding], dim=0)
batch_mask = None
if embedding_mask is not None:
batch_mask = torch.cat([embedding_mask, embedding_mask], dim=0)
batch_features = None
features = kwargs.pop("features", None)
if self.use_context_features:
batch_features = torch.cat([features, features], dim=0)
batch_channels = None
channels_list = kwargs.pop("channels_list", None)
if self.use_context_channels:
batch_channels = []
for channels in channels_list:
batch_channels += [torch.cat([channels, channels], dim=0)]
# Compute both normal and fixed embedding outputs
batch_out = super().forward(batch_x, batch_time, embedding=batch_embed, embedding_mask=batch_mask, features=batch_features, channels_list=batch_channels, **kwargs)
out, out_masked = batch_out.chunk(2, dim=0)
else:
# Compute both normal and fixed embedding outputs
out = super().forward(x, time, embedding=embedding, embedding_mask=embedding_mask, **kwargs)
out_masked = super().forward(x, time, embedding=fixed_embedding, embedding_mask=embedding_mask, **kwargs)
out_cfg = out_masked + (out - out_masked) * embedding_scale
if rescale_cfg:
out_std = out.std(dim=1, keepdim=True)
out_cfg_std = out_cfg.std(dim=1, keepdim=True)
return scale_phi * (out_cfg * (out_std/out_cfg_std)) + (1-scale_phi) * out_cfg
else:
return out_cfg
else:
return super().forward(x, time, embedding=embedding, embedding_mask=embedding_mask, **kwargs)
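# Standalone worked sketch of the classifier-free guidance mixing used above:
# guidance extrapolates from the masked (unconditional) output towards the
# conditional one, optionally rescaled to the conditional output's std.
def _demo_cfg_mix(out: Tensor, out_masked: Tensor, embedding_scale: float = 3.0,
                  rescale_cfg: bool = False, scale_phi: float = 0.4) -> Tensor:
    out_cfg = out_masked + (out - out_masked) * embedding_scale
    if not rescale_cfg:
        return out_cfg
    out_std = out.std(dim=1, keepdim=True)
    out_cfg_std = out_cfg.std(dim=1, keepdim=True)
    return scale_phi * (out_cfg * (out_std / out_cfg_std)) + (1 - scale_phi) * out_cfg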
class UNetNCCA1d(UNet1d):
"""UNet1d with Noise Channel Conditioning Augmentation"""
def __init__(self, context_features: int, **kwargs):
super().__init__(context_features=context_features, **kwargs)
self.embedder = NumberEmbedder(features=context_features)
def expand(self, x: Any, shape: Tuple[int, ...]) -> Tensor:
x = x if torch.is_tensor(x) else torch.tensor(x)
return x.expand(shape)
def forward( # type: ignore
self,
x: Tensor,
time: Tensor,
*,
channels_list: Sequence[Tensor],
channels_augmentation: Union[
bool, Sequence[bool], Sequence[Sequence[bool]], Tensor
] = False,
channels_scale: Union[
float, Sequence[float], Sequence[Sequence[float]], Tensor
] = 0,
**kwargs,
) -> Tensor:
b, n = x.shape[0], len(channels_list)
channels_augmentation = self.expand(channels_augmentation, shape=(b, n)).to(x)
channels_scale = self.expand(channels_scale, shape=(b, n)).to(x)
# Augmentation (for each channel list item)
for i in range(n):
scale = channels_scale[:, i] * channels_augmentation[:, i]
scale = rearrange(scale, "b -> b 1 1")
item = channels_list[i]
channels_list[i] = torch.randn_like(item) * scale + item * (1 - scale) # type: ignore # noqa
# Scale embedding (sum reduction if more than one channel list item)
channels_scale_emb = self.embedder(channels_scale)
channels_scale_emb = reduce(channels_scale_emb, "b n d -> b d", "sum")
return super().forward(
x=x,
time=time,
channels_list=channels_list,
features=channels_scale_emb,
**kwargs,
)
class UNetAll1d(UNetCFG1d, UNetNCCA1d):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def forward(self, *args, **kwargs): # type: ignore
return UNetCFG1d.forward(self, *args, **kwargs)
def XUNet1d(type: str = "base", **kwargs) -> UNet1d:
if type == "base":
return UNet1d(**kwargs)
elif type == "all":
return UNetAll1d(**kwargs)
elif type == "cfg":
return UNetCFG1d(**kwargs)
elif type == "ncca":
return UNetNCCA1d(**kwargs)
else:
raise ValueError(f"Unknown XUNet1d type: {type}")
class NumberEmbedder(nn.Module):
def __init__(
self,
features: int,
dim: int = 256,
):
super().__init__()
self.features = features
self.embedding = TimePositionalEmbedding(dim=dim, out_features=features)
def forward(self, x: Union[List[float], Tensor]) -> Tensor:
if not torch.is_tensor(x):
device = next(self.embedding.parameters()).device
x = torch.tensor(x, device=device)
assert isinstance(x, Tensor)
shape = x.shape
x = rearrange(x, "... -> (...)")
embedding = self.embedding(x)
x = embedding.view(*shape, self.features)
return x # type: ignore
"""
Audio Transforms
"""
class STFT(nn.Module):
"""Helper for torch stft and istft"""
def __init__(
self,
num_fft: int = 1023,
hop_length: int = 256,
window_length: Optional[int] = None,
length: Optional[int] = None,
use_complex: bool = False,
):
super().__init__()
self.num_fft = num_fft
self.hop_length = default(hop_length, floor(num_fft // 4))
self.window_length = default(window_length, num_fft)
self.length = length
self.register_buffer("window", torch.hann_window(self.window_length))
self.use_complex = use_complex
def encode(self, wave: Tensor) -> Tuple[Tensor, Tensor]:
b = wave.shape[0]
wave = rearrange(wave, "b c t -> (b c) t")
stft = torch.stft(
wave,
n_fft=self.num_fft,
hop_length=self.hop_length,
win_length=self.window_length,
window=self.window, # type: ignore
return_complex=True,
normalized=True,
)
if self.use_complex:
# Returns real and imaginary
stft_a, stft_b = stft.real, stft.imag
else:
# Returns magnitude and phase matrices
magnitude, phase = torch.abs(stft), torch.angle(stft)
stft_a, stft_b = magnitude, phase
return rearrange_many((stft_a, stft_b), "(b c) f l -> b c f l", b=b)
def decode(self, stft_a: Tensor, stft_b: Tensor) -> Tensor:
b, l = stft_a.shape[0], stft_a.shape[-1] # noqa
length = closest_power_2(l * self.hop_length)
stft_a, stft_b = rearrange_many((stft_a, stft_b), "b c f l -> (b c) f l")
if self.use_complex:
real, imag = stft_a, stft_b
else:
magnitude, phase = stft_a, stft_b
real, imag = magnitude * torch.cos(phase), magnitude * torch.sin(phase)
stft = torch.stack([real, imag], dim=-1)
wave = torch.istft(
stft,
n_fft=self.num_fft,
hop_length=self.hop_length,
win_length=self.window_length,
window=self.window, # type: ignore
length=default(self.length, length),
normalized=True,
)
return rearrange(wave, "(b c) t -> b c t", b=b)
def encode1d(
self, wave: Tensor, stacked: bool = True
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
stft_a, stft_b = self.encode(wave)
stft_a, stft_b = rearrange_many((stft_a, stft_b), "b c f l -> b (c f) l")
return torch.cat((stft_a, stft_b), dim=1) if stacked else (stft_a, stft_b)
def decode1d(self, stft_pair: Tensor) -> Tensor:
f = self.num_fft // 2 + 1
stft_a, stft_b = stft_pair.chunk(chunks=2, dim=1)
stft_a, stft_b = rearrange_many((stft_a, stft_b), "b (c f) l -> b c f l", f=f)
return self.decode(stft_a, stft_b)
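# Shape sketch (illustrative sizes): encode1d folds the STFT of a (B, C, T)
# waveform into a (B, 2 * C * (num_fft // 2 + 1), frames) feature map that
# decode1d can invert.
def _demo_stft_encode1d():
    stft = STFT(num_fft=1023, hop_length=256)
    wave = torch.randn(2, 1, 4096)
    feats = stft.encode1d(wave)                 # (2, 1024, 17)
    return feats.shape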
from functools import reduce
import math
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch.backends.cuda import sdp_kernel
from packaging import version
from dac.nn.layers import Snake1d
class ResidualBlock(nn.Module):
def __init__(self, main, skip=None):
super().__init__()
self.main = nn.Sequential(*main)
self.skip = skip if skip else nn.Identity()
def forward(self, input):
return self.main(input) + self.skip(input)
class ResConvBlock(ResidualBlock):
def __init__(self, c_in, c_mid, c_out, is_last=False, kernel_size=5, conv_bias=True, use_snake=False):
skip = None if c_in == c_out else nn.Conv1d(c_in, c_out, 1, bias=False)
super().__init__([
nn.Conv1d(c_in, c_mid, kernel_size, padding=kernel_size//2, bias=conv_bias),
nn.GroupNorm(1, c_mid),
Snake1d(c_mid) if use_snake else nn.GELU(),
nn.Conv1d(c_mid, c_out, kernel_size, padding=kernel_size//2, bias=conv_bias),
nn.GroupNorm(1, c_out) if not is_last else nn.Identity(),
(Snake1d(c_out) if use_snake else nn.GELU()) if not is_last else nn.Identity(),
], skip)
class SelfAttention1d(nn.Module):
def __init__(self, c_in, n_head=1, dropout_rate=0.):
super().__init__()
assert c_in % n_head == 0
self.norm = nn.GroupNorm(1, c_in)
self.n_head = n_head
self.qkv_proj = nn.Conv1d(c_in, c_in * 3, 1)
self.out_proj = nn.Conv1d(c_in, c_in, 1)
self.dropout = nn.Dropout(dropout_rate, inplace=True)
self.use_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0')
if not self.use_flash:
return
device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
if device_properties.major == 8 and device_properties.minor == 0:
# Use flash attention for A100 GPUs
self.sdp_kernel_config = (True, False, False)
else:
# Don't use flash attention for other GPUs
self.sdp_kernel_config = (False, True, True)
def forward(self, input):
n, c, s = input.shape
qkv = self.qkv_proj(self.norm(input))
qkv = qkv.view(
[n, self.n_head * 3, c // self.n_head, s]).transpose(2, 3)
q, k, v = qkv.chunk(3, dim=1)
scale = k.shape[3]**-0.25
if self.use_flash:
with sdp_kernel(*self.sdp_kernel_config):
y = F.scaled_dot_product_attention(q, k, v, is_causal=False).contiguous().view([n, c, s])
else:
att = ((q * scale) @ (k.transpose(2, 3) * scale)).softmax(3)
y = (att @ v).transpose(2, 3).contiguous().view([n, c, s])
return input + self.dropout(self.out_proj(y))
class SkipBlock(nn.Module):
def __init__(self, *main):
super().__init__()
self.main = nn.Sequential(*main)
def forward(self, input):
return torch.cat([self.main(input), input], dim=1)
class FourierFeatures(nn.Module):
def __init__(self, in_features, out_features, std=1.):
super().__init__()
assert out_features % 2 == 0
self.weight = nn.Parameter(torch.randn(
[out_features // 2, in_features]) * std)
def forward(self, input):
f = 2 * math.pi * input @ self.weight.T
return torch.cat([f.cos(), f.sin()], dim=-1)
def expand_to_planes(input, shape):
return input[..., None].repeat([1, 1, shape[2]])
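# Shape sketch (illustrative sizes): random Fourier features for scalar
# timesteps, as used by the diffusion transformer further below.
def _demo_fourier_features():
    ff = FourierFeatures(in_features=1, out_features=256)
    t = torch.rand(4, 1)                        # (batch, 1) timesteps
    return ff(t).shape                          # (4, 256)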
_kernels = {
'linear':
[1 / 8, 3 / 8, 3 / 8, 1 / 8],
'cubic':
[-0.01171875, -0.03515625, 0.11328125, 0.43359375,
0.43359375, 0.11328125, -0.03515625, -0.01171875],
'lanczos3':
[0.003689131001010537, 0.015056144446134567, -0.03399861603975296,
-0.066637322306633, 0.13550527393817902, 0.44638532400131226,
0.44638532400131226, 0.13550527393817902, -0.066637322306633,
-0.03399861603975296, 0.015056144446134567, 0.003689131001010537]
}
class Downsample1d(nn.Module):
def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False):
super().__init__()
self.pad_mode = pad_mode
kernel_1d = torch.tensor(_kernels[kernel])
self.pad = kernel_1d.shape[0] // 2 - 1
self.register_buffer('kernel', kernel_1d)
self.channels_last = channels_last
def forward(self, x):
if self.channels_last:
x = x.permute(0, 2, 1)
x = F.pad(x, (self.pad,) * 2, self.pad_mode)
weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]])
indices = torch.arange(x.shape[1], device=x.device)
weight[indices, indices] = self.kernel.to(weight)
x = F.conv1d(x, weight, stride=2)
if self.channels_last:
x = x.permute(0, 2, 1)
return x
class Upsample1d(nn.Module):
def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False):
super().__init__()
self.pad_mode = pad_mode
kernel_1d = torch.tensor(_kernels[kernel]) * 2
self.pad = kernel_1d.shape[0] // 2 - 1
self.register_buffer('kernel', kernel_1d)
self.channels_last = channels_last
def forward(self, x):
if self.channels_last:
x = x.permute(0, 2, 1)
x = F.pad(x, ((self.pad + 1) // 2,) * 2, self.pad_mode)
weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]])
indices = torch.arange(x.shape[1], device=x.device)
weight[indices, indices] = self.kernel.to(weight)
x = F.conv_transpose1d(x, weight, stride=2, padding=self.pad * 2 + 1)
if self.channels_last:
x = x.permute(0, 2, 1)
return x
def Downsample1d_2(
in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2
) -> nn.Module:
assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even"
return nn.Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=factor * kernel_multiplier + 1,
stride=factor,
padding=factor * (kernel_multiplier // 2),
)
def Upsample1d_2(
in_channels: int, out_channels: int, factor: int, use_nearest: bool = False
) -> nn.Module:
if factor == 1:
return nn.Conv1d(
in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1
)
if use_nearest:
return nn.Sequential(
nn.Upsample(scale_factor=factor, mode="nearest"),
nn.Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
padding=1,
),
)
else:
return nn.ConvTranspose1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=factor * 2,
stride=factor,
padding=factor // 2 + factor % 2,
output_padding=factor % 2,
)
def zero_init(layer):
nn.init.zeros_(layer.weight)
if layer.bias is not None:
nn.init.zeros_(layer.bias)
return layer
def rms_norm(x, scale, eps):
dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32))
mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True)
scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps)
return x * scale.to(x.dtype)
#rms_norm = torch.compile(rms_norm)
class AdaRMSNorm(nn.Module):
def __init__(self, features, cond_features, eps=1e-6):
super().__init__()
self.eps = eps
self.linear = zero_init(nn.Linear(cond_features, features, bias=False))
def extra_repr(self):
return f"eps={self.eps},"
def forward(self, x, cond):
return rms_norm(x, self.linear(cond)[:, None, :] + 1, self.eps)
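# Usage sketch (toy sizes): conditional RMSNorm whose gain is predicted from a
# conditioning vector; the zero-initialised linear makes it start out as a
# plain RMSNorm.
def _demo_ada_rms_norm():
    norm = AdaRMSNorm(features=64, cond_features=32)
    x = torch.randn(2, 10, 64)
    cond = torch.randn(2, 32)
    return norm(x, cond).shape                  # (2, 10, 64)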
def normalize(x, eps=1e-4):
dim = list(range(1, x.ndim))
n = torch.linalg.vector_norm(x, dim=dim, keepdim=True)
alpha = np.sqrt(n.numel() / x.numel())
return x / torch.add(eps, n, alpha=alpha)
class ForcedWNConv1d(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=1):
super().__init__()
self.weight = nn.Parameter(torch.randn([out_channels, in_channels, kernel_size]))
def forward(self, x):
if self.training:
with torch.no_grad():
self.weight.copy_(normalize(self.weight))
fan_in = self.weight[0].numel()
w = normalize(self.weight) / math.sqrt(fan_in)
return F.conv1d(x, w, padding='same')
# Kernels
use_compile = True
def compile(function, *args, **kwargs):
if not use_compile:
return function
try:
return torch.compile(function, *args, **kwargs)
except RuntimeError:
return function
@compile
def linear_geglu(x, weight, bias=None):
x = x @ weight.mT
if bias is not None:
x = x + bias
x, gate = x.chunk(2, dim=-1)
return x * F.gelu(gate)
@compile
def rms_norm(x, scale, eps):
dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32))
mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True)
scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps)
return x * scale.to(x.dtype)
# Layers
class LinearGEGLU(nn.Linear):
def __init__(self, in_features, out_features, bias=True):
super().__init__(in_features, out_features * 2, bias=bias)
self.out_features = out_features
def forward(self, x):
return linear_geglu(x, self.weight, self.bias)
class RMSNorm(nn.Module):
def __init__(self, shape, fix_scale = False, eps=1e-6):
super().__init__()
self.eps = eps
if fix_scale:
self.register_buffer("scale", torch.ones(shape))
else:
self.scale = nn.Parameter(torch.ones(shape))
def extra_repr(self):
return f"shape={tuple(self.scale.shape)}, eps={self.eps}"
def forward(self, x):
return rms_norm(x, self.scale, self.eps)
def snake_beta(x, alpha, beta):
return x + (1.0 / (beta + 0.000000001)) * pow(torch.sin(x * alpha), 2)
# try:
# snake_beta = torch.compile(snake_beta)
# except RuntimeError:
# pass
# Adapted from https://github.com/NVIDIA/BigVGAN/blob/main/activations.py under MIT license
# License available in LICENSES/LICENSE_NVIDIA.txt
class SnakeBeta(nn.Module):
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
super(SnakeBeta, self).__init__()
self.in_features = in_features
# initialize alpha
self.alpha_logscale = alpha_logscale
if self.alpha_logscale: # log scale alphas initialized to zeros
self.alpha = nn.Parameter(torch.zeros(in_features) * alpha)
self.beta = nn.Parameter(torch.zeros(in_features) * alpha)
else: # linear scale alphas initialized to ones
self.alpha = nn.Parameter(torch.ones(in_features) * alpha)
self.beta = nn.Parameter(torch.ones(in_features) * alpha)
self.alpha.requires_grad = alpha_trainable
self.beta.requires_grad = alpha_trainable
self.no_div_by_zero = 0.000000001
def forward(self, x):
alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
beta = self.beta.unsqueeze(0).unsqueeze(-1)
if self.alpha_logscale:
alpha = torch.exp(alpha)
beta = torch.exp(beta)
x = snake_beta(x, alpha, beta)
return x
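# Usage sketch (illustrative sizes): the snake-beta activation is applied
# channel-wise to (B, C, T) features and preserves the shape.
def _demo_snake_beta():
    act = SnakeBeta(in_features=16)
    x = torch.randn(2, 16, 100)
    return act(x).shape                         # (2, 16, 100)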
import typing as tp
import torch
from einops import rearrange
from torch import nn
from torch.nn import functional as F
from x_transformers import ContinuousTransformerWrapper, Encoder
from .blocks import FourierFeatures
from .transformer import ContinuousTransformer
from .transformer_use_mask import ContinuousTransformer as ContinuousTransformer_mask
class DiffusionTransformer(nn.Module):
def __init__(self,
io_channels=32,
patch_size=1,
embed_dim=768,
cond_token_dim=0,
project_cond_tokens=True,
global_cond_dim=0,
project_global_cond=True,
input_concat_dim=0,
prepend_cond_dim=0,
depth=12,
num_heads=8,
        transformer_type: tp.Literal["x-transformers", "continuous_transformer", "continuous_transformer_with_mask"] = "x-transformers",
global_cond_type: tp.Literal["prepend", "adaLN"] = "prepend",
**kwargs):
super().__init__()
self.cond_token_dim = cond_token_dim
# Timestep embeddings
timestep_features_dim = 256
self.timestep_features = FourierFeatures(1, timestep_features_dim)
self.to_timestep_embed = nn.Sequential(
nn.Linear(timestep_features_dim, embed_dim, bias=True),
nn.SiLU(),
nn.Linear(embed_dim, embed_dim, bias=True),
)
if cond_token_dim > 0:
# Conditioning tokens
cond_embed_dim = cond_token_dim if not project_cond_tokens else embed_dim
self.to_cond_embed = nn.Sequential(
nn.Linear(cond_token_dim, cond_embed_dim, bias=False),
nn.SiLU(),
nn.Linear(cond_embed_dim, cond_embed_dim, bias=False)
)
else:
cond_embed_dim = 0
self.to_cond_embed = nn.Identity()
if global_cond_dim > 0:
# Global conditioning
global_embed_dim = global_cond_dim if not project_global_cond else embed_dim
self.to_global_embed = nn.Sequential(
nn.Linear(global_cond_dim, global_embed_dim, bias=False),
nn.SiLU(),
nn.Linear(global_embed_dim, global_embed_dim, bias=False)
)
if prepend_cond_dim > 0:
# Prepend conditioning
self.to_prepend_embed = nn.Sequential(
nn.Linear(prepend_cond_dim, embed_dim, bias=False),
nn.SiLU(),
nn.Linear(embed_dim, embed_dim, bias=False)
)
self.input_concat_dim = input_concat_dim
dim_in = io_channels + self.input_concat_dim
self.patch_size = patch_size
# Transformer
self.transformer_type = transformer_type
self.global_cond_type = global_cond_type
if self.transformer_type == "x-transformers":
self.transformer = ContinuousTransformerWrapper(
dim_in=dim_in * patch_size,
dim_out=io_channels * patch_size,
max_seq_len=0, # Not relevant without absolute positional embeds
attn_layers=Encoder(
dim=embed_dim,
depth=depth,
heads=num_heads,
attn_flash=True,
cross_attend=cond_token_dim > 0,
dim_context=None if cond_embed_dim == 0 else cond_embed_dim,
zero_init_branch_output=True,
use_abs_pos_emb=False,
rotary_pos_emb=True,
ff_swish=True,
ff_glu=True,
**kwargs
)
)
elif self.transformer_type == "continuous_transformer":
global_dim = None
if self.global_cond_type == "adaLN":
# The global conditioning is projected to the embed_dim already at this point
global_dim = embed_dim
self.transformer = ContinuousTransformer(
dim=embed_dim,
depth=depth,
dim_heads=embed_dim // num_heads,
dim_in=dim_in * patch_size,
dim_out=io_channels * patch_size,
cross_attend=cond_token_dim > 0,
cond_token_dim=cond_embed_dim,
global_cond_dim=global_dim,
**kwargs
)
elif self.transformer_type == "continuous_transformer_with_mask":
global_dim = None
if self.global_cond_type == "adaLN":
# The global conditioning is projected to the embed_dim already at this point
global_dim = embed_dim
self.transformer = ContinuousTransformer_mask(
dim=embed_dim,
depth=depth,
dim_heads=embed_dim // num_heads,
dim_in=dim_in * patch_size,
dim_out=io_channels * patch_size,
cross_attend=cond_token_dim > 0,
cond_token_dim=cond_embed_dim,
global_cond_dim=global_dim,
**kwargs
)
else:
raise ValueError(f"Unknown transformer type: {self.transformer_type}")
self.preprocess_conv = nn.Conv1d(dim_in, dim_in, 1, bias=False)
nn.init.zeros_(self.preprocess_conv.weight)
self.postprocess_conv = nn.Conv1d(io_channels, io_channels, 1, bias=False)
nn.init.zeros_(self.postprocess_conv.weight)
def _forward(
self,
x,
t,
mask=None,
cross_attn_cond=None,
cross_attn_cond_mask=None,
input_concat_cond=None,
global_embed=None,
prepend_cond=None,
prepend_cond_mask=None,
return_info=False,
**kwargs):
        ### 1. TODO: rewrite this so it can handle conditioning inputs of different lengths
if cross_attn_cond is not None:
cross_attn_cond = self.to_cond_embed(cross_attn_cond)
if global_embed is not None:
# Project the global conditioning to the embedding dimension
global_embed = self.to_global_embed(global_embed)
prepend_inputs = None
prepend_mask = None
prepend_length = 0
if prepend_cond is not None:
# Project the prepend conditioning to the embedding dimension
prepend_cond = self.to_prepend_embed(prepend_cond)
prepend_inputs = prepend_cond
if prepend_cond_mask is not None:
prepend_mask = prepend_cond_mask
if input_concat_cond is not None:
# Interpolate input_concat_cond to the same length as x
if input_concat_cond.shape[2] != x.shape[2]:
input_concat_cond = F.interpolate(input_concat_cond, (x.shape[2],), mode='nearest')
x = torch.cat([x, input_concat_cond], dim=1)
# Get the batch of timestep embeddings
try:
timestep_embed = self.to_timestep_embed(self.timestep_features(t[:, None])) # (b, embed_dim)
except Exception as e:
print("t.shape:", t.shape, "x.shape", x.shape)
print("t:", t)
raise e
# Timestep embedding is considered a global embedding. Add to the global conditioning if it exists
if global_embed is not None:
global_embed = global_embed + timestep_embed
else:
global_embed = timestep_embed
# Add the global_embed to the prepend inputs if there is no global conditioning support in the transformer
if self.global_cond_type == "prepend":
if prepend_inputs is None:
# Prepend inputs are just the global embed, and the mask is all ones
prepend_inputs = global_embed.unsqueeze(1)
prepend_mask = torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)
else:
# Prepend inputs are the prepend conditioning + the global embed
prepend_inputs = torch.cat([prepend_inputs, global_embed.unsqueeze(1)], dim=1)
prepend_mask = torch.cat([prepend_mask, torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)],
dim=1)
prepend_length = prepend_inputs.shape[1]
x = self.preprocess_conv(x) + x
x = rearrange(x, "b c t -> b t c")
extra_args = {}
if self.global_cond_type == "adaLN":
extra_args["global_cond"] = global_embed
if self.patch_size > 1:
x = rearrange(x, "b (t p) c -> b t (c p)", p=self.patch_size)
if self.transformer_type == "x-transformers":
output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond,
context_mask=cross_attn_cond_mask, mask=mask, prepend_mask=prepend_mask,
**extra_args, **kwargs)
elif self.transformer_type in ["continuous_transformer","continuous_transformer_with_mask"] :
output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond,
context_mask=cross_attn_cond_mask, mask=mask, prepend_mask=prepend_mask,
return_info=return_info, **extra_args, **kwargs)
if return_info:
output, info = output
elif self.transformer_type == "mm_transformer":
output = self.transformer(x, context=cross_attn_cond, mask=mask, context_mask=cross_attn_cond_mask,
**extra_args, **kwargs)
output = rearrange(output, "b t c -> b c t")[:, :, prepend_length:]
if self.patch_size > 1:
output = rearrange(output, "b (c p) t -> b c (t p)", p=self.patch_size)
output = self.postprocess_conv(output) + output
if return_info:
return output, info
return output
def forward(
self,
x,
t,
cross_attn_cond=None,
cross_attn_cond_mask=None,
negative_cross_attn_cond=None,
negative_cross_attn_mask=None,
input_concat_cond=None,
global_embed=None,
negative_global_embed=None,
prepend_cond=None,
prepend_cond_mask=None,
cfg_scale=1.0,
cfg_dropout_prob=0.0,
causal=False,
scale_phi=0.0,
mask=None,
return_info=False,
**kwargs):
assert causal == False, "Causal mode is not supported for DiffusionTransformer"
if cross_attn_cond_mask is not None:
cross_attn_cond_mask = cross_attn_cond_mask.bool()
cross_attn_cond_mask = None # Temporarily disabling conditioning masks due to kernel issue for flash attention
if prepend_cond_mask is not None:
prepend_cond_mask = prepend_cond_mask.bool()
# CFG dropout
if cfg_dropout_prob > 0.0:
if cross_attn_cond is not None:
null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device)
dropout_mask = torch.bernoulli(
torch.full((cross_attn_cond.shape[0], 1, 1), cfg_dropout_prob, device=cross_attn_cond.device)).to(
torch.bool)
cross_attn_cond = torch.where(dropout_mask, null_embed, cross_attn_cond)
if prepend_cond is not None:
null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device)
dropout_mask = torch.bernoulli(
torch.full((prepend_cond.shape[0], 1, 1), cfg_dropout_prob, device=prepend_cond.device)).to(
torch.bool)
prepend_cond = torch.where(dropout_mask, null_embed, prepend_cond)
if cfg_scale != 1.0 and (cross_attn_cond is not None or prepend_cond is not None):
# Classifier-free guidance
# Concatenate conditioned and unconditioned inputs on the batch dimension
batch_inputs = torch.cat([x, x], dim=0)
batch_timestep = torch.cat([t, t], dim=0)
if global_embed is not None:
batch_global_cond = torch.cat([global_embed, global_embed], dim=0)
else:
batch_global_cond = None
if input_concat_cond is not None:
batch_input_concat_cond = torch.cat([input_concat_cond, input_concat_cond], dim=0)
else:
batch_input_concat_cond = None
batch_cond = None
batch_cond_masks = None
# Handle CFG for cross-attention conditioning
if cross_attn_cond is not None:
null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device)
# For negative cross-attention conditioning, replace the null embed with the negative cross-attention conditioning
if negative_cross_attn_cond is not None:
# If there's a negative cross-attention mask, set the masked tokens to the null embed
if negative_cross_attn_mask is not None:
negative_cross_attn_mask = negative_cross_attn_mask.to(torch.bool).unsqueeze(2)
negative_cross_attn_cond = torch.where(negative_cross_attn_mask, negative_cross_attn_cond,
null_embed)
batch_cond = torch.cat([cross_attn_cond, negative_cross_attn_cond], dim=0)
else:
batch_cond = torch.cat([cross_attn_cond, null_embed], dim=0)
if cross_attn_cond_mask is not None:
batch_cond_masks = torch.cat([cross_attn_cond_mask, cross_attn_cond_mask], dim=0)
batch_prepend_cond = None
batch_prepend_cond_mask = None
if prepend_cond is not None:
null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device)
batch_prepend_cond = torch.cat([prepend_cond, null_embed], dim=0)
if prepend_cond_mask is not None:
batch_prepend_cond_mask = torch.cat([prepend_cond_mask, prepend_cond_mask], dim=0)
if mask is not None:
batch_masks = torch.cat([mask, mask], dim=0)
else:
batch_masks = None
batch_output = self._forward(
batch_inputs,
batch_timestep,
cross_attn_cond=batch_cond,
cross_attn_cond_mask=batch_cond_masks,
mask=batch_masks,
input_concat_cond=batch_input_concat_cond,
global_embed=batch_global_cond,
prepend_cond=batch_prepend_cond,
prepend_cond_mask=batch_prepend_cond_mask,
return_info=return_info,
**kwargs)
if return_info:
batch_output, info = batch_output
cond_output, uncond_output = torch.chunk(batch_output, 2, dim=0)
cfg_output = uncond_output + (cond_output - uncond_output) * cfg_scale
# CFG Rescale
if scale_phi != 0.0:
cond_out_std = cond_output.std(dim=1, keepdim=True)
out_cfg_std = cfg_output.std(dim=1, keepdim=True)
output = scale_phi * (cfg_output * (cond_out_std / out_cfg_std)) + (1 - scale_phi) * cfg_output
else:
output = cfg_output
if return_info:
return output, info
return output
else:
return self._forward(
x,
t,
cross_attn_cond=cross_attn_cond,
cross_attn_cond_mask=cross_attn_cond_mask,
input_concat_cond=input_concat_cond,
global_embed=global_embed,
prepend_cond=prepend_cond,
prepend_cond_mask=prepend_cond_mask,
mask=mask,
return_info=return_info,
**kwargs
)
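# ---------------------------------------------------------------------------
# Note on the classifier-free guidance (CFG) branch in `forward` above: the
# conditioned and unconditioned passes share one doubled batch, and the guided
# output is `uncond + (cond - uncond) * cfg_scale`, optionally rescaled toward
# the conditioned branch's per-feature std when scale_phi > 0. The helper below
# is a standalone sketch of that combination step (illustrative only, not used
# by the module).
def _cfg_combine_sketch(batch_output, cfg_scale=3.0, scale_phi=0.0):
    """Split a doubled batch into (cond, uncond) halves and apply CFG."""
    import torch
    cond, uncond = torch.chunk(batch_output, 2, dim=0)
    guided = uncond + (cond - uncond) * cfg_scale
    if scale_phi != 0.0:
        # CFG rescale: pull the guided output back toward the conditioned
        # branch's statistics to avoid over-saturated samples.
        guided = scale_phi * (guided * (cond.std(dim=1, keepdim=True) /
                                        guided.std(dim=1, keepdim=True))) + \
                 (1 - scale_phi) * guided
    return guided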
import typing as tp
import torch
from einops import rearrange
from torch import nn
from torch.nn import functional as F
from x_transformers import ContinuousTransformerWrapper, Encoder
from .blocks import FourierFeatures
from .transformer import ContinuousTransformer
from model.stable import transformer_use_mask
class DiffusionTransformerV2(nn.Module):
def __init__(self,
io_channels=32,
patch_size=1,
embed_dim=768,
cond_token_dim=0,
project_cond_tokens=True,
global_cond_dim=0,
project_global_cond=True,
input_concat_dim=0,
prepend_cond_dim=0,
depth=12,
num_heads=8,
transformer_type: tp.Literal["x-transformers", "continuous_transformer"] = "x-transformers",
global_cond_type: tp.Literal["prepend", "adaLN"] = "prepend",
**kwargs):
super().__init__()
d_model = embed_dim
n_head = num_heads
n_layers = depth
encoder_layer = torch.nn.TransformerEncoderLayer(batch_first=True,
norm_first=True,
d_model=d_model,
nhead=n_head)
self.transformer = torch.nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
# ===================================== timestep embedding
timestep_features_dim = 256
self.timestep_features = FourierFeatures(1, timestep_features_dim)
self.to_timestep_embed = nn.Sequential(
nn.Linear(timestep_features_dim, embed_dim, bias=True),
nn.SiLU(),
nn.Linear(embed_dim, embed_dim, bias=True),
)
def _forward(
self,
Xt_btd,
t, #(1d)
mu_btd,
):
timestep_embed = self.to_timestep_embed(self.timestep_features(t[:, None])) # (b, embed_dim)
        # Assumed intent: concatenate the timestep embedding token with the
        # condition and the noisy input along the time axis (all are b x t x d).
        cated_input = torch.cat([timestep_embed.unsqueeze(1), mu_btd, Xt_btd], dim=1)
        ### 1. TODO: rewrite this to support conditions of different lengths
if cross_attn_cond is not None:
cross_attn_cond = self.to_cond_embed(cross_attn_cond)
if global_embed is not None:
# Project the global conditioning to the embedding dimension
global_embed = self.to_global_embed(global_embed)
prepend_inputs = None
prepend_mask = None
prepend_length = 0
if prepend_cond is not None:
# Project the prepend conditioning to the embedding dimension
prepend_cond = self.to_prepend_embed(prepend_cond)
prepend_inputs = prepend_cond
if prepend_cond_mask is not None:
prepend_mask = prepend_cond_mask
if input_concat_cond is not None:
# Interpolate input_concat_cond to the same length as x
if input_concat_cond.shape[2] != x.shape[2]:
input_concat_cond = F.interpolate(input_concat_cond, (x.shape[2],), mode='nearest')
x = torch.cat([x, input_concat_cond], dim=1)
# Get the batch of timestep embeddings
try:
timestep_embed = self.to_timestep_embed(self.timestep_features(t[:, None])) # (b, embed_dim)
except Exception as e:
print("t.shape:", t.shape, "x.shape", x.shape)
print("t:", t)
raise e
# Timestep embedding is considered a global embedding. Add to the global conditioning if it exists
if global_embed is not None:
global_embed = global_embed + timestep_embed
else:
global_embed = timestep_embed
# Add the global_embed to the prepend inputs if there is no global conditioning support in the transformer
if self.global_cond_type == "prepend":
if prepend_inputs is None:
# Prepend inputs are just the global embed, and the mask is all ones
prepend_inputs = global_embed.unsqueeze(1)
prepend_mask = torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)
else:
# Prepend inputs are the prepend conditioning + the global embed
prepend_inputs = torch.cat([prepend_inputs, global_embed.unsqueeze(1)], dim=1)
prepend_mask = torch.cat([prepend_mask, torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)],
dim=1)
prepend_length = prepend_inputs.shape[1]
x = self.preprocess_conv(x) + x
x = rearrange(x, "b c t -> b t c")
extra_args = {}
if self.global_cond_type == "adaLN":
extra_args["global_cond"] = global_embed
if self.patch_size > 1:
x = rearrange(x, "b (t p) c -> b t (c p)", p=self.patch_size)
if self.transformer_type == "x-transformers":
output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond,
context_mask=cross_attn_cond_mask, mask=mask, prepend_mask=prepend_mask,
**extra_args, **kwargs)
elif self.transformer_type in ["continuous_transformer", "continuous_transformer_with_mask"]:
output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond,
context_mask=cross_attn_cond_mask, mask=mask, prepend_mask=prepend_mask,
return_info=return_info, **extra_args, **kwargs)
if return_info:
output, info = output
elif self.transformer_type == "mm_transformer":
output = self.transformer(x, context=cross_attn_cond, mask=mask, context_mask=cross_attn_cond_mask,
**extra_args, **kwargs)
output = rearrange(output, "b t c -> b c t")[:, :, prepend_length:]
if self.patch_size > 1:
output = rearrange(output, "b (c p) t -> b c (t p)", p=self.patch_size)
output = self.postprocess_conv(output) + output
if return_info:
return output, info
return output
def forward(
self,
x,
t,
cross_attn_cond=None,
cross_attn_cond_mask=None,
negative_cross_attn_cond=None,
negative_cross_attn_mask=None,
input_concat_cond=None,
global_embed=None,
negative_global_embed=None,
prepend_cond=None,
prepend_cond_mask=None,
cfg_scale=1.0,
cfg_dropout_prob=0.0,
causal=False,
scale_phi=0.0,
mask=None,
return_info=False,
**kwargs):
assert causal == False, "Causal mode is not supported for DiffusionTransformer"
if cross_attn_cond_mask is not None:
cross_attn_cond_mask = cross_attn_cond_mask.bool()
cross_attn_cond_mask = None # Temporarily disabling conditioning masks due to kernel issue for flash attention
if prepend_cond_mask is not None:
prepend_cond_mask = prepend_cond_mask.bool()
# CFG dropout
if cfg_dropout_prob > 0.0:
if cross_attn_cond is not None:
null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device)
dropout_mask = torch.bernoulli(
torch.full((cross_attn_cond.shape[0], 1, 1), cfg_dropout_prob, device=cross_attn_cond.device)).to(
torch.bool)
cross_attn_cond = torch.where(dropout_mask, null_embed, cross_attn_cond)
if prepend_cond is not None:
null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device)
dropout_mask = torch.bernoulli(
torch.full((prepend_cond.shape[0], 1, 1), cfg_dropout_prob, device=prepend_cond.device)).to(
torch.bool)
prepend_cond = torch.where(dropout_mask, null_embed, prepend_cond)
if cfg_scale != 1.0 and (cross_attn_cond is not None or prepend_cond is not None):
# Classifier-free guidance
# Concatenate conditioned and unconditioned inputs on the batch dimension
batch_inputs = torch.cat([x, x], dim=0)
batch_timestep = torch.cat([t, t], dim=0)
if global_embed is not None:
batch_global_cond = torch.cat([global_embed, global_embed], dim=0)
else:
batch_global_cond = None
if input_concat_cond is not None:
batch_input_concat_cond = torch.cat([input_concat_cond, input_concat_cond], dim=0)
else:
batch_input_concat_cond = None
batch_cond = None
batch_cond_masks = None
# Handle CFG for cross-attention conditioning
if cross_attn_cond is not None:
null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device)
# For negative cross-attention conditioning, replace the null embed with the negative cross-attention conditioning
if negative_cross_attn_cond is not None:
# If there's a negative cross-attention mask, set the masked tokens to the null embed
if negative_cross_attn_mask is not None:
negative_cross_attn_mask = negative_cross_attn_mask.to(torch.bool).unsqueeze(2)
negative_cross_attn_cond = torch.where(negative_cross_attn_mask, negative_cross_attn_cond,
null_embed)
batch_cond = torch.cat([cross_attn_cond, negative_cross_attn_cond], dim=0)
else:
batch_cond = torch.cat([cross_attn_cond, null_embed], dim=0)
if cross_attn_cond_mask is not None:
batch_cond_masks = torch.cat([cross_attn_cond_mask, cross_attn_cond_mask], dim=0)
batch_prepend_cond = None
batch_prepend_cond_mask = None
if prepend_cond is not None:
null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device)
batch_prepend_cond = torch.cat([prepend_cond, null_embed], dim=0)
if prepend_cond_mask is not None:
batch_prepend_cond_mask = torch.cat([prepend_cond_mask, prepend_cond_mask], dim=0)
if mask is not None:
batch_masks = torch.cat([mask, mask], dim=0)
else:
batch_masks = None
batch_output = self._forward(
batch_inputs,
batch_timestep,
cross_attn_cond=batch_cond,
cross_attn_cond_mask=batch_cond_masks,
mask=batch_masks,
input_concat_cond=batch_input_concat_cond,
global_embed=batch_global_cond,
prepend_cond=batch_prepend_cond,
prepend_cond_mask=batch_prepend_cond_mask,
return_info=return_info,
**kwargs)
if return_info:
batch_output, info = batch_output
cond_output, uncond_output = torch.chunk(batch_output, 2, dim=0)
cfg_output = uncond_output + (cond_output - uncond_output) * cfg_scale
# CFG Rescale
if scale_phi != 0.0:
cond_out_std = cond_output.std(dim=1, keepdim=True)
out_cfg_std = cfg_output.std(dim=1, keepdim=True)
output = scale_phi * (cfg_output * (cond_out_std / out_cfg_std)) + (1 - scale_phi) * cfg_output
else:
output = cfg_output
if return_info:
return output, info
return output
else:
return self._forward(
x,
t,
cross_attn_cond=cross_attn_cond,
cross_attn_cond_mask=cross_attn_cond_mask,
input_concat_cond=input_concat_cond,
global_embed=global_embed,
prepend_cond=prepend_cond,
prepend_cond_mask=prepend_cond_mask,
mask=mask,
return_info=return_info,
**kwargs
)
import torch
import math
from tqdm import trange, tqdm
import k_diffusion as K
# Define the noise schedule and sampling loop
def get_alphas_sigmas(t):
"""Returns the scaling factors for the clean image (alpha) and for the
noise (sigma), given a timestep."""
return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)
def alpha_sigma_to_t(alpha, sigma):
"""Returns a timestep, given the scaling factors for the clean image and for
the noise."""
return torch.atan2(sigma, alpha) / math.pi * 2
def t_to_alpha_sigma(t):
"""Returns the scaling factors for the clean image and for the noise, given
a timestep."""
return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)
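# The cosine schedule above satisfies alpha(t)^2 + sigma(t)^2 = 1 for t in
# [0, 1], and alpha_sigma_to_t inverts t_to_alpha_sigma. A quick sanity check
# (illustrative only):
def _alpha_sigma_roundtrip_check():
    t = torch.linspace(0, 1, 5)
    alpha, sigma = get_alphas_sigmas(t)
    assert torch.allclose(alpha ** 2 + sigma ** 2, torch.ones_like(t))
    assert torch.allclose(alpha_sigma_to_t(alpha, sigma), t, atol=1e-6)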
@torch.no_grad()
def sample_discrete_euler(model, x, steps, sigma_max=1, **extra_args):
"""Draws samples from a model given starting noise. Euler method"""
# Make tensor of ones to broadcast the single t values
ts = x.new_ones([x.shape[0]])
# Create the noise schedule
t = torch.linspace(sigma_max, 0, steps + 1)
#alphas, sigmas = 1-t, t
for t_curr, t_prev in tqdm(zip(t[:-1], t[1:])):
# Broadcast the current timestep to the correct shape
t_curr_tensor = t_curr * torch.ones(
(x.shape[0],), dtype=x.dtype, device=x.device
)
dt = t_prev - t_curr # we solve backwards in our formulation
x = x + dt * model(x, t_curr_tensor, **extra_args) #.denoise(x, denoiser, t_curr_tensor, cond, uc)
# If we are on the last timestep, output the denoised image
return x
@torch.no_grad()
def sample(model, x, steps, eta, **extra_args):
"""Draws samples from a model given starting noise. v-diffusion"""
ts = x.new_ones([x.shape[0]])
# Create the noise schedule
t = torch.linspace(1, 0, steps + 1)[:-1]
alphas, sigmas = get_alphas_sigmas(t)
# The sampling loop
for i in trange(steps):
# Get the model output (v, the predicted velocity)
with torch.cuda.amp.autocast():
v = model(x, ts * t[i], **extra_args).float()
# Predict the noise and the denoised image
pred = x * alphas[i] - v * sigmas[i]
eps = x * sigmas[i] + v * alphas[i]
# If we are not on the last timestep, compute the noisy image for the
# next timestep.
if i < steps - 1:
# If eta > 0, adjust the scaling factor for the predicted noise
# downward according to the amount of additional noise to add
ddim_sigma = eta * (sigmas[i + 1]**2 / sigmas[i]**2).sqrt() * \
(1 - alphas[i]**2 / alphas[i + 1]**2).sqrt()
adjusted_sigma = (sigmas[i + 1]**2 - ddim_sigma**2).sqrt()
# Recombine the predicted noise and predicted denoised image in the
# correct proportions for the next step
x = pred * alphas[i + 1] + eps * adjusted_sigma
# Add the correct amount of fresh noise
if eta:
x += torch.randn_like(x) * ddim_sigma
# If we are on the last timestep, output the denoised image
return pred
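# In the DDIM-style loop above, eta controls stochasticity: eta=0 gives a
# deterministic update (no fresh noise is added), while larger eta re-injects
# noise each step via ddim_sigma. The final return value is the last predicted
# clean signal `pred` rather than `x`.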
# Soft mask inpainting is just shrinking hard (binary) mask inpainting
# Given a float-valued soft mask (values between 0 and 1), get the binary mask for this particular step
def get_bmask(i, steps, mask):
strength = (i+1)/(steps)
# convert to binary mask
bmask = torch.where(mask<=strength,1,0)
return bmask
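# Example of the shrinking behaviour (illustrative only): with steps=4 and a
# soft mask of [0.2, 0.6, 0.9], the binary mask per step is
#   i=0 -> [1, 0, 0], i=1 -> [1, 0, 0], i=2 -> [1, 1, 0], i=3 -> [1, 1, 1]
# i.e. regions with lower soft-mask values are pinned to the (re-)noised
# init_data from earlier steps, and the pinned region grows as sampling
# progresses.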
def make_cond_model_fn(model, cond_fn):
def cond_model_fn(x, sigma, **kwargs):
with torch.enable_grad():
x = x.detach().requires_grad_()
denoised = model(x, sigma, **kwargs)
cond_grad = cond_fn(x, sigma, denoised=denoised, **kwargs).detach()
cond_denoised = denoised.detach() + cond_grad * K.utils.append_dims(sigma**2, x.ndim)
return cond_denoised
return cond_model_fn
# Uses k-diffusion from https://github.com/crowsonkb/k-diffusion
# init_data is init_audio as latents (if this is latent diffusion)
# For sampling, set both init_data and mask to None
# For variations, set init_data
# For inpainting, set both init_data & mask
def sample_k(
model_fn,
noise,
init_data=None,
mask=None,
steps=100,
sampler_type="dpmpp-2m-sde",
sigma_min=0.5,
sigma_max=50,
rho=1.0, device="cuda",
callback=None,
cond_fn=None,
**extra_args
):
denoiser = K.external.VDenoiser(model_fn)
if cond_fn is not None:
denoiser = make_cond_model_fn(denoiser, cond_fn)
# Make the list of sigmas. Sigma values are scalars related to the amount of noise each denoising step has
sigmas = K.sampling.get_sigmas_polyexponential(steps, sigma_min, sigma_max, rho, device=device)
# Scale the initial noise by sigma
noise = noise * sigmas[0]
wrapped_callback = callback
if mask is None and init_data is not None:
# VARIATION (no inpainting)
# set the initial latent to the init_data, and noise it with initial sigma
x = init_data + noise
elif mask is not None and init_data is not None:
# INPAINTING
bmask = get_bmask(0, steps, mask)
# initial noising
input_noised = init_data + noise
# set the initial latent to a mix of init_data and noise, based on step 0's binary mask
x = input_noised * bmask + noise * (1-bmask)
# define the inpainting callback function (Note: side effects, it mutates x)
# See https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py#L596C13-L596C105
# callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
# This is called immediately after `denoised = model(x, sigmas[i] * s_in, **extra_args)`
def inpainting_callback(args):
i = args["i"]
x = args["x"]
sigma = args["sigma"]
#denoised = args["denoised"]
# noise the init_data input with this step's appropriate amount of noise
input_noised = init_data + torch.randn_like(init_data) * sigma
# shrinking hard mask
bmask = get_bmask(i, steps, mask)
# mix input_noise with x, using binary mask
new_x = input_noised * bmask + x * (1-bmask)
# mutate x
x[:,:,:] = new_x[:,:,:]
# wrap together the inpainting callback and the user-submitted callback.
if callback is None:
wrapped_callback = inpainting_callback
else:
wrapped_callback = lambda args: (inpainting_callback(args), callback(args))
else:
# SAMPLING
# set the initial latent to noise
x = noise
with torch.cuda.amp.autocast():
if sampler_type == "k-heun":
return K.sampling.sample_heun(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
elif sampler_type == "k-lms":
return K.sampling.sample_lms(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
elif sampler_type == "k-dpmpp-2s-ancestral":
return K.sampling.sample_dpmpp_2s_ancestral(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
elif sampler_type == "k-dpm-2":
return K.sampling.sample_dpm_2(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
elif sampler_type == "k-dpm-fast":
return K.sampling.sample_dpm_fast(denoiser, x, sigma_min, sigma_max, steps, disable=False, callback=wrapped_callback, extra_args=extra_args)
elif sampler_type == "k-dpm-adaptive":
return K.sampling.sample_dpm_adaptive(denoiser, x, sigma_min, sigma_max, rtol=0.01, atol=0.01, disable=False, callback=wrapped_callback, extra_args=extra_args)
elif sampler_type == "dpmpp-2m-sde":
return K.sampling.sample_dpmpp_2m_sde(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
elif sampler_type == "dpmpp-3m-sde":
return K.sampling.sample_dpmpp_3m_sde(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
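# Minimal usage sketch for sample_k (illustrative; `model` is assumed to be a
# v-objective diffusion model with signature model(x, t, **extra_args), and the
# shapes are hypothetical):
#
#   noise = torch.randn(1, 80, 256, device="cuda")
#   latents = sample_k(model, noise, steps=100, sampler_type="dpmpp-2m-sde",
#                      sigma_min=0.5, sigma_max=50, device="cuda")
#
# For a variation, additionally pass `init_data` (clean latents); for
# inpainting, pass both `init_data` and a float `mask` in [0, 1]
# (see get_bmask above).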
# Uses discrete Euler sampling for rectified flow models
# init_data is init_audio as latents (if this is latent diffusion)
# For sampling, set both init_data and mask to None
# For variations, set init_data
# For inpainting, set both init_data & mask
def sample_rf(
model_fn,
noise,
init_data=None,
steps=100,
sigma_max=1,
device="cuda",
callback=None,
cond_fn=None,
**extra_args
):
if sigma_max > 1:
sigma_max = 1
    if cond_fn is not None:
        model_fn = make_cond_model_fn(model_fn, cond_fn)
wrapped_callback = callback
if init_data is not None:
# VARIATION (no inpainting)
# Interpolate the init data and the noise for init audio
x = init_data * (1 - sigma_max) + noise * sigma_max
else:
# SAMPLING
# set the initial latent to noise
x = noise
with torch.cuda.amp.autocast():
# TODO: Add callback support
#return sample_discrete_euler(model_fn, x, steps, sigma_max, callback=wrapped_callback, **extra_args)
return sample_discrete_euler(model_fn, x, steps, sigma_max, **extra_args)
import torch
from torch.nn import functional as F
from .dit import DiffusionTransformer
from .adp import UNet1d
from .sampling import sample
import math
from model.base import BaseModule
import pdb
target_length = 1536
def pad_and_create_mask(matrix, target_length):
T = matrix.shape[2]
if T > target_length:
raise ValueError("The third dimension length %s should not exceed %s" % (T, target_length))
padding_size = target_length - T
padded_matrix = F.pad(matrix, (0, padding_size), "constant", 0)
mask = torch.ones((1, target_length))
mask[:, T:] = 0 # Set the padding part to 0
return padded_matrix.to(matrix.device), mask.to(matrix.device)
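# Example (illustrative only): right-padding a (1, 80, 229) feature tensor to
# target_length produces a (1, 80, 1536) tensor and a (1, 1536) mask whose
# first 229 entries are 1.
def _pad_and_mask_example():
    mel = torch.randn(1, 80, 229)
    padded, mask = pad_and_create_mask(mel, target_length)
    return padded.shape, int(mask.sum())  # (torch.Size([1, 80, 1536]), 229)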
class Stable_Diffusion(BaseModule):
def __init__(self, io_channels, input_concat_dim=None, embed_dim=768, depth=24, num_heads=24,
project_cond_tokens=False, transformer_type="continuous_transformer"):
super(Stable_Diffusion, self).__init__()
self.diffusion = DiffusionTransformer(
io_channels=io_channels,
input_concat_dim=input_concat_dim,
embed_dim=embed_dim,
# cond_token_dim=target_length,
depth=depth,
num_heads=num_heads,
project_cond_tokens=project_cond_tokens,
transformer_type=transformer_type,
)
# self.diffusion = UNet1d(
# in_channels=80,
# channels=256,
# resnet_groups=16,
# kernel_multiplier_downsample=2,
# multipliers=[4, 4, 4, 5, 5],
            # factors=[1, 2, 2, 4],  # strided convolutions shorten inputs of differing length
# num_blocks=[2, 2, 2, 2],
# attentions=[1, 3, 3, 3, 3],
# attention_heads=16,
# attention_multiplier=4,
# use_nearest_upsample=False,
# use_skip_scale=True,
# use_context_time=True
# )
self.rng = torch.quasirandom.SobolEngine(1, scramble=True)
@torch.no_grad()
def forward(self, mu, mask, n_timesteps):
# pdb.set_trace()
mask = mask.squeeze(1)
noise = torch.randn_like(mu).to(mu.device)
# mu_pad, mu_pad_mask = pad_and_create_mask(mu, target_length)
# extra_args = {"cross_attn_cond": mu, "cross_attn_cond_mask": mask, "mask": mask}
extra_args = {"input_concat_cond": mu, "mask": mask}
fakes = sample(self.diffusion, noise, n_timesteps, 0, **extra_args)
return fakes
def compute_loss(self, x0, mask, mu):
# pdb.set_trace()
t = self.rng.draw(x0.shape[0])[:, 0].to(x0.device)
alphas, sigmas = torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)
alphas = alphas[:, None, None]
sigmas = sigmas[:, None, None]
noise = torch.randn_like(x0)
noised_inputs = x0 * alphas + noise * sigmas
targets = noise * alphas - x0 * sigmas
mask = mask.squeeze(1)
# mu_pad, mu_pad_mask = pad_and_create_mask(mu, target_length)
# output = self.diffusion(noised_inputs, t, cross_attn_cond=mu,
# cross_attn_cond_mask=mask, mask=mask, cfg_dropout_prob=0.1)
# pdb.set_trace()
output = self.diffusion(noised_inputs, # [bs, 80, 229]
t, # (bs,)
input_concat_cond=mu,
mask=mask, # [bs, 229]
cfg_dropout_prob=0.1)
return self.mse_loss(output, targets, mask), output
def mse_loss(self, output, targets, mask):
mse_loss = F.mse_loss(output, targets, reduction='none')
if mask.ndim == 2 and mse_loss.ndim == 3:
mask = mask.unsqueeze(1)
if mask.shape[1] != mse_loss.shape[1]:
mask = mask.repeat(1, mse_loss.shape[1], 1)
mse_loss = mse_loss * mask
mse_loss = mse_loss.mean()
return mse_loss
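# Rough usage sketch for the wrapper above (illustrative; the constructor
# arguments and shapes mirror the inline comments in compute_loss and are
# otherwise hypothetical):
#
#   model = Stable_Diffusion(io_channels=80, input_concat_dim=80)
#   x0   = torch.randn(4, 80, 229)    # target features
#   mu   = torch.randn(4, 80, 229)    # conditioning features (input_concat_cond)
#   mask = torch.ones(4, 1, 229)
#   loss, _ = model.compute_loss(x0, mask, mu)   # training objective (v-prediction)
#   out     = model(mu, mask, n_timesteps=50)    # sampling from noise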
import torch
from torch.nn import functional as F
from .dit import DiffusionTransformer
from .adp import UNet1d
from .sampling import sample
import math
from model.base import BaseModule
import pdb
target_length = 1536
def pad_and_create_mask(matrix, target_length):
T = matrix.shape[2]
if T > target_length:
raise ValueError("The third dimension length %s should not exceed %s"%(T, target_length))
padding_size = target_length - T
padded_matrix = F.pad(matrix, (0, padding_size), "constant", 0)
mask = torch.ones((1, target_length))
mask[:, T:] = 0 # Set the padding part to 0
return padded_matrix.to(matrix.device), mask.to(matrix.device)
class Stable_Diffusion(BaseModule):
def __init__(self):
super(Stable_Diffusion, self).__init__()
self.diffusion = DiffusionTransformer(
io_channels=80,
# input_concat_dim=80,
embed_dim=768,
# cond_token_dim=target_length,
depth=24,
num_heads=24,
project_cond_tokens=False,
transformer_type="continuous_transformer",
)
# self.diffusion = UNet1d(
# in_channels=80,
# channels=256,
# resnet_groups=16,
# kernel_multiplier_downsample=2,
# multipliers=[4, 4, 4, 5, 5],
            # factors=[1, 2, 2, 4],  # strided convolutions shorten inputs of differing length
# num_blocks=[2, 2, 2, 2],
# attentions=[1, 3, 3, 3, 3],
# attention_heads=16,
# attention_multiplier=4,
# use_nearest_upsample=False,
# use_skip_scale=True,
# use_context_time=True
# )
self.rng = torch.quasirandom.SobolEngine(1, scramble=True)
@torch.no_grad()
def forward(self, mu, mask, n_timesteps):
# pdb.set_trace()
mask = mask.squeeze(1)
# noise = torch.randn_like(mu).to(mu.device)
# mu_pad, mu_pad_mask = pad_and_create_mask(mu, target_length)
# extra_args = {"cross_attn_cond": mu, "cross_attn_cond_mask": mask, "mask": mask}
extra_args = {"mask": mask}
fakes = sample(self.diffusion, mu, n_timesteps, 0, **extra_args)
return fakes
def compute_loss(self, x0, mask, mu):
# pdb.set_trace()
t = self.rng.draw(x0.shape[0])[:, 0].to(x0.device)
alphas, sigmas = torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)
alphas = alphas[:, None, None]
sigmas = sigmas[:, None, None]
noise = torch.randn_like(x0)
noised_inputs = x0 * alphas + noise * sigmas
targets = mu * alphas - x0 * sigmas
mask = mask.squeeze(1)
# mu_pad, mu_pad_mask = pad_and_create_mask(mu, target_length)
# output = self.diffusion(noised_inputs, t, cross_attn_cond=mu,
# cross_attn_cond_mask=mask, mask=mask, cfg_dropout_prob=0.1)
output = self.diffusion(noised_inputs, t, mask=mask, cfg_dropout_prob=0.1)
return self.mse_loss(output, targets, mask), output
def mse_loss(self, output, targets, mask):
mse_loss = F.mse_loss(output, targets, reduction='none')
if mask.ndim == 2 and mse_loss.ndim == 3:
mask = mask.unsqueeze(1)
if mask.shape[1] != mse_loss.shape[1]:
mask = mask.repeat(1, mse_loss.shape[1], 1)
        mse_loss = mse_loss[mask.bool()]
mse_loss = mse_loss.mean()
return mse_loss
import pdb
from functools import reduce, partial
from packaging import version
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
import torch
import torch.nn.functional as F
from torch import nn, einsum
from torch.cuda.amp import autocast
from typing import Callable, Literal
try:
from flash_attn import flash_attn_func, flash_attn_kvpacked_func
except ImportError as e:
print(e)
print('flash_attn not installed, disabling Flash Attention')
flash_attn_kvpacked_func = None
flash_attn_func = None
try:
import natten
except ImportError:
natten = None
def checkpoint(function, *args, **kwargs):
kwargs.setdefault("use_reentrant", False)
return torch.utils.checkpoint.checkpoint(function, *args, **kwargs)
# Copied and modified from https://github.com/lucidrains/x-transformers/blob/main/x_transformers/attend.py under MIT License
# License can be found in LICENSES/LICENSE_XTRANSFORMERS.txt
def create_causal_mask(i, j, device):
return torch.ones((i, j), device = device, dtype = torch.bool).triu(j - i + 1)
def or_reduce(masks):
head, *body = masks
for rest in body:
head = head | rest
return head
# positional embeddings
class AbsolutePositionalEmbedding(nn.Module):
def __init__(self, dim, max_seq_len):
super().__init__()
self.scale = dim ** -0.5
self.max_seq_len = max_seq_len
self.emb = nn.Embedding(max_seq_len, dim)
def forward(self, x, pos = None, seq_start_pos = None):
seq_len, device = x.shape[1], x.device
assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}'
if pos is None:
pos = torch.arange(seq_len, device = device)
if seq_start_pos is not None:
pos = (pos - seq_start_pos[..., None]).clamp(min = 0)
pos_emb = self.emb(pos)
pos_emb = pos_emb * self.scale
return pos_emb
class ScaledSinusoidalEmbedding(nn.Module):
def __init__(self, dim, theta = 10000):
super().__init__()
assert (dim % 2) == 0, 'dimension must be divisible by 2'
self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5)
half_dim = dim // 2
freq_seq = torch.arange(half_dim).float() / half_dim
inv_freq = theta ** -freq_seq
self.register_buffer('inv_freq', inv_freq, persistent = False)
def forward(self, x, pos = None, seq_start_pos = None):
seq_len, device = x.shape[1], x.device
if pos is None:
pos = torch.arange(seq_len, device = device)
if seq_start_pos is not None:
pos = pos - seq_start_pos[..., None]
emb = einsum('i, j -> i j', pos, self.inv_freq)
emb = torch.cat((emb.sin(), emb.cos()), dim = -1)
return emb * self.scale
class RotaryEmbedding(nn.Module):
def __init__(
self,
dim,
use_xpos = False,
scale_base = 512,
interpolation_factor = 1.,
base = 10000,
base_rescale_factor = 1.
):
super().__init__()
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
# has some connection to NTK literature
# https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
base *= base_rescale_factor ** (dim / (dim - 2))
inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer('inv_freq', inv_freq)
assert interpolation_factor >= 1.
self.interpolation_factor = interpolation_factor
if not use_xpos:
self.register_buffer('scale', None)
return
scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
self.scale_base = scale_base
self.register_buffer('scale', scale)
def forward_from_seq_len(self, seq_len):
device = self.inv_freq.device
t = torch.arange(seq_len, device = device)
return self.forward(t)
@autocast(enabled = False)
def forward(self, t):
device = self.inv_freq.device
t = t.to(torch.float32)
t = t / self.interpolation_factor
freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
freqs = torch.cat((freqs, freqs), dim = -1)
if self.scale is None:
return freqs, 1.
        seq_len = t.shape[-1]
        power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base
scale = self.scale ** rearrange(power, 'n -> n 1')
scale = torch.cat((scale, scale), dim = -1)
return freqs, scale
def rotate_half(x):
x = rearrange(x, '... (j d) -> ... j d', j = 2)
x1, x2 = x.unbind(dim = -2)
return torch.cat((-x2, x1), dim = -1)
@autocast(enabled = False)
def apply_rotary_pos_emb(t, freqs, scale = 1):
out_dtype = t.dtype
# cast to float32 if necessary for numerical stability
dtype = reduce(torch.promote_types, (t.dtype, freqs.dtype, torch.float32))
rot_dim, seq_len = freqs.shape[-1], t.shape[-2]
freqs, t = freqs.to(dtype), t.to(dtype)
freqs = freqs[-seq_len:, :]
if t.ndim == 4 and freqs.ndim == 3:
freqs = rearrange(freqs, 'b n d -> b 1 n d')
# partial rotary embeddings, Wang et al. GPT-J
t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:]
t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
t, t_unrotated = t.to(out_dtype), t_unrotated.to(out_dtype)
return torch.cat((t, t_unrotated), dim = -1)
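# Shape/identity sketch for the rotary helpers above (illustrative only): with
# head dim 64 and RotaryEmbedding(32), only the first 32 channels are rotated
# (partial rotary), and the rotation preserves the per-token norm.
def _rotary_sketch():
    rope = RotaryEmbedding(32)
    freqs, _ = rope.forward_from_seq_len(8)      # (8, 32)
    q = torch.randn(1, 4, 8, 64)                 # (batch, heads, seq, dim_head)
    q_rot = apply_rotary_pos_emb(q, freqs)
    assert q_rot.shape == q.shape
    assert torch.allclose(q_rot.norm(dim=-1), q.norm(dim=-1), atol=1e-5)
    return q_rot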
# norms
class LayerNorm(nn.Module):
def __init__(self, dim, bias=False, fix_scale=False):
"""
bias-less layernorm has been shown to be more stable. most newer models have moved towards rmsnorm, also bias-less
"""
super().__init__()
if fix_scale:
self.register_buffer("gamma", torch.ones(dim))
else:
self.gamma = nn.Parameter(torch.ones(dim))
if bias:
self.beta = nn.Parameter(torch.zeros(dim))
else:
self.register_buffer("beta", torch.zeros(dim))
def forward(self, x):
return F.layer_norm(x, x.shape[-1:], weight=self.gamma, bias=self.beta)
# feedforward
class GLU(nn.Module):
def __init__(
self,
dim_in,
dim_out,
activation: Callable,
use_conv = False,
conv_kernel_size = 3,
):
super().__init__()
self.act = activation
self.proj = nn.Linear(dim_in, dim_out * 2) if not use_conv else nn.Conv1d(dim_in, dim_out * 2, conv_kernel_size, padding = (conv_kernel_size // 2))
self.use_conv = use_conv
def forward(self, x):
if self.use_conv:
x = rearrange(x, 'b n d -> b d n')
x = self.proj(x)
x = rearrange(x, 'b d n -> b n d')
else:
x = self.proj(x)
x, gate = x.chunk(2, dim = -1)
return x * self.act(gate)
class FeedForward(nn.Module):
def __init__(
self,
dim,
dim_out = None,
mult = 4,
no_bias = False,
glu = True,
use_conv = False,
conv_kernel_size = 3,
zero_init_output = True,
):
super().__init__()
inner_dim = int(dim * mult)
# Default to SwiGLU
activation = nn.SiLU()
dim_out = dim if dim_out is None else dim_out
if glu:
linear_in = GLU(dim, inner_dim, activation)
else:
linear_in = nn.Sequential(
Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
nn.Linear(dim, inner_dim, bias = not no_bias) if not use_conv else nn.Conv1d(dim, inner_dim, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias),
Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
activation
)
linear_out = nn.Linear(inner_dim, dim_out, bias = not no_bias) if not use_conv else nn.Conv1d(inner_dim, dim_out, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias)
# init last linear layer to 0
if zero_init_output:
nn.init.zeros_(linear_out.weight)
if not no_bias:
nn.init.zeros_(linear_out.bias)
self.ff = nn.Sequential(
linear_in,
Rearrange('b d n -> b n d') if use_conv else nn.Identity(),
linear_out,
Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
)
def forward(self, x):
return self.ff(x)
class Attention(nn.Module):
def __init__(
self,
dim,
dim_heads = 64,
dim_context = None,
causal = False,
zero_init_output=True,
qk_norm: Literal['l2', 'ln', 'none'] = 'none',
natten_kernel_size = None
):
super().__init__()
self.dim = dim
self.dim_heads = dim_heads
self.causal = causal
dim_kv = dim_context if dim_context is not None else dim
self.num_heads = dim // dim_heads
self.kv_heads = dim_kv // dim_heads
if dim_context is not None:
self.to_q = nn.Linear(dim, dim, bias=False)
self.to_kv = nn.Linear(dim_kv, dim_kv * 2, bias=False)
else:
self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
self.to_out = nn.Linear(dim, dim, bias=False)
if zero_init_output:
nn.init.zeros_(self.to_out.weight)
self.qk_norm = qk_norm
if self.qk_norm == "ln":
self.q_norm = nn.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6)
self.k_norm = nn.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6)
# Using 1d neighborhood attention
self.natten_kernel_size = natten_kernel_size
if natten_kernel_size is not None:
return
self.use_pt_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0')
self.use_fa_flash = torch.cuda.is_available() and flash_attn_func is not None
# pdb.set_trace()
self.use_fa_flash = False
self.sdp_kwargs = dict(
enable_flash = True,
enable_math = True,
enable_mem_efficient = True
)
def flash_attn(
self,
q,
k,
v,
mask = None,
causal = None
):
batch, heads, q_len, _, k_len, device = *q.shape, k.shape[-2], q.device
kv_heads = k.shape[1]
# Recommended for multi-query single-key-value attention by Tri Dao
# kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64])
if heads != kv_heads:
# Repeat interleave kv_heads to match q_heads
heads_per_kv_head = heads // kv_heads
k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v))
if k.ndim == 3:
k = rearrange(k, 'b ... -> b 1 ...').expand_as(q)
if v.ndim == 3:
v = rearrange(v, 'b ... -> b 1 ...').expand_as(q)
causal = self.causal if causal is None else causal
if q_len == 1 and causal:
causal = False
if mask is not None:
assert mask.ndim == 4
mask = mask.expand(batch, heads, q_len, k_len)
# handle kv cache - this should be bypassable in updated flash attention 2
if k_len > q_len and causal:
            causal_mask = create_causal_mask(q_len, k_len, device = device)
if mask is None:
mask = ~causal_mask
else:
mask = mask & ~causal_mask
causal = False
# manually handle causal mask, if another mask was given
row_is_entirely_masked = None
if mask is not None and causal:
            causal_mask = create_causal_mask(q_len, k_len, device = device)
mask = mask & ~causal_mask
# protect against an entire row being masked out
row_is_entirely_masked = ~mask.any(dim = -1)
mask[..., 0] = mask[..., 0] | row_is_entirely_masked
causal = False
with torch.backends.cuda.sdp_kernel(**self.sdp_kwargs):
out = F.scaled_dot_product_attention(
q, k, v,
attn_mask = mask,
is_causal = causal
)
# for a row that is entirely masked out, should zero out the output of that row token
if row_is_entirely_masked is not None:
out = out.masked_fill(row_is_entirely_masked[..., None], 0.)
return out
def forward(
self,
x,
context = None,
mask = None,
context_mask = None,
rotary_pos_emb = None,
causal = None
):
h, kv_h, has_context = self.num_heads, self.kv_heads, context is not None
kv_input = context if has_context else x
if hasattr(self, 'to_q'):
# Use separate linear projections for q and k/v
q = self.to_q(x)
q = rearrange(q, 'b n (h d) -> b h n d', h = h)
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = kv_h), (k, v))
else:
# Use fused linear projection
q, k, v = self.to_qkv(x).chunk(3, dim=-1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
# Normalize q and k for cosine sim attention
if self.qk_norm == "l2":
q = F.normalize(q, dim=-1)
k = F.normalize(k, dim=-1)
elif self.qk_norm == "ln":
q = self.q_norm(q)
k = self.k_norm(k)
if rotary_pos_emb is not None and not has_context:
freqs, _ = rotary_pos_emb
q_dtype = q.dtype
k_dtype = k.dtype
q = q.to(torch.float32)
k = k.to(torch.float32)
freqs = freqs.to(torch.float32)
q = apply_rotary_pos_emb(q, freqs)
k = apply_rotary_pos_emb(k, freqs)
q = q.to(q_dtype)
k = k.to(k_dtype)
input_mask = context_mask
if input_mask is None and not has_context:
input_mask = mask
# determine masking
masks = []
final_attn_mask = None # The mask that will be applied to the attention matrix, taking all masks into account
if input_mask is not None:
input_mask = rearrange(input_mask, 'b j -> b 1 1 j')
masks.append(~input_mask)
# Other masks will be added here later
if len(masks) > 0:
final_attn_mask = ~or_reduce(masks)
n, device = q.shape[-2], q.device
causal = self.causal if causal is None else causal
if n == 1 and causal:
causal = False
if self.natten_kernel_size is not None:
if natten is None:
raise ImportError('natten not installed, please install natten to use neighborhood attention')
dtype_in = q.dtype
q, k, v = map(lambda t: t.to(torch.float32), (q, k, v))
attn = natten.functional.natten1dqk(q, k, kernel_size = self.natten_kernel_size, dilation=1)
if final_attn_mask is not None:
attn = attn.masked_fill(final_attn_mask, -torch.finfo(attn.dtype).max)
attn = F.softmax(attn, dim=-1, dtype=torch.float32)
out = natten.functional.natten1dav(attn, v, kernel_size = self.natten_kernel_size, dilation=1).to(dtype_in)
# Prioritize Flash Attention 2
elif self.use_fa_flash:
# pdb.set_trace()
assert final_attn_mask is None, 'masking not yet supported for Flash Attention 2'
# Flash Attention 2 requires FP16 inputs
fa_dtype_in = q.dtype
q, k, v = map(lambda t: rearrange(t, 'b h n d -> b n h d').to(torch.float16), (q, k, v))
out = flash_attn_func(q, k, v, causal = causal)
out = rearrange(out.to(fa_dtype_in), 'b n h d -> b h n d')
# Fall back to PyTorch implementation
elif self.use_pt_flash:
out = self.flash_attn(q, k, v, causal = causal, mask = final_attn_mask)
else:
# Fall back to custom implementation
if h != kv_h:
# Repeat interleave kv_heads to match q_heads
heads_per_kv_head = h // kv_h
k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v))
scale = 1. / (q.shape[-1] ** 0.5)
kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d'
dots = einsum(f'b h i d, {kv_einsum_eq} -> b h i j', q, k) * scale
i, j, dtype = *dots.shape[-2:], dots.dtype
mask_value = -torch.finfo(dots.dtype).max
if final_attn_mask is not None:
dots = dots.masked_fill(~final_attn_mask, mask_value)
if causal:
                causal_mask = create_causal_mask(i, j, device = device)
dots = dots.masked_fill(causal_mask, mask_value)
attn = F.softmax(dots, dim=-1, dtype=torch.float32)
attn = attn.type(dtype)
out = einsum(f'b h i j, {kv_einsum_eq} -> b h i d', attn, v)
# merge heads
out = rearrange(out, ' b h n d -> b n (h d)')
# Communicate between heads
# with autocast(enabled = False):
# out_dtype = out.dtype
# out = out.to(torch.float32)
# out = self.to_out(out).to(out_dtype)
out = self.to_out(out)
if mask is not None:
mask = rearrange(mask, 'b n -> b n 1')
out = out.masked_fill(~mask, 0.)
return out
class ConformerModule(nn.Module):
def __init__(
self,
dim,
norm_kwargs = {},
):
super().__init__()
self.dim = dim
self.in_norm = LayerNorm(dim, **norm_kwargs)
self.pointwise_conv = nn.Conv1d(dim, dim, kernel_size=1, bias=False)
self.glu = GLU(dim, dim, nn.SiLU())
self.depthwise_conv = nn.Conv1d(dim, dim, kernel_size=17, groups=dim, padding=8, bias=False)
self.mid_norm = LayerNorm(dim, **norm_kwargs) # This is a batch norm in the original but I don't like batch norm
self.swish = nn.SiLU()
self.pointwise_conv_2 = nn.Conv1d(dim, dim, kernel_size=1, bias=False)
def forward(self, x):
x = self.in_norm(x)
x = rearrange(x, 'b n d -> b d n')
x = self.pointwise_conv(x)
x = rearrange(x, 'b d n -> b n d')
x = self.glu(x)
x = rearrange(x, 'b n d -> b d n')
x = self.depthwise_conv(x)
x = rearrange(x, 'b d n -> b n d')
x = self.mid_norm(x)
x = self.swish(x)
x = rearrange(x, 'b n d -> b d n')
x = self.pointwise_conv_2(x)
x = rearrange(x, 'b d n -> b n d')
return x
class TransformerBlock(nn.Module):
def __init__(
self,
dim,
dim_heads = 64,
cross_attend = False,
dim_context = None,
global_cond_dim = None,
causal = False,
zero_init_branch_outputs = True,
conformer = False,
layer_ix = -1,
remove_norms = False,
attn_kwargs = {},
ff_kwargs = {},
norm_kwargs = {}
):
super().__init__()
self.dim = dim
self.dim_heads = dim_heads
self.cross_attend = cross_attend
self.dim_context = dim_context
self.causal = causal
self.pre_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else nn.Identity()
self.self_attn = Attention(
dim,
dim_heads = dim_heads,
causal = causal,
zero_init_output=zero_init_branch_outputs,
**attn_kwargs
)
        ### 2. TODO: this is the main part that needs to be modified
if cross_attend:
self.cross_attend_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else nn.Identity()
self.cross_attn = Attention(
dim,
dim_heads = dim_heads,
dim_context=dim_context,
causal = causal,
zero_init_output=zero_init_branch_outputs,
**attn_kwargs
)
self.ff_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else nn.Identity()
self.ff = FeedForward(dim, zero_init_output=zero_init_branch_outputs, **ff_kwargs)
self.layer_ix = layer_ix
self.conformer = ConformerModule(dim, norm_kwargs=norm_kwargs) if conformer else None
self.global_cond_dim = global_cond_dim
if global_cond_dim is not None:
self.to_scale_shift_gate = nn.Sequential(
nn.SiLU(),
nn.Linear(global_cond_dim, dim * 6, bias=False)
)
nn.init.zeros_(self.to_scale_shift_gate[1].weight)
#nn.init.zeros_(self.to_scale_shift_gate_self[1].bias)
def forward(
self,
x,
context = None,
global_cond=None,
mask = None,
context_mask = None,
rotary_pos_emb = None
):
if self.global_cond_dim is not None and self.global_cond_dim > 0 and global_cond is not None:
scale_self, shift_self, gate_self, scale_ff, shift_ff, gate_ff = self.to_scale_shift_gate(global_cond).unsqueeze(1).chunk(6, dim = -1)
# self-attention with adaLN
residual = x
x = self.pre_norm(x)
x = x * (1 + scale_self) + shift_self
x = self.self_attn(x, mask = mask, rotary_pos_emb = rotary_pos_emb)
x = x * torch.sigmoid(1 - gate_self)
x = x + residual
if context is not None:
x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask)
if self.conformer is not None:
x = x + self.conformer(x)
# feedforward with adaLN
residual = x
x = self.ff_norm(x)
x = x * (1 + scale_ff) + shift_ff
x = self.ff(x)
x = x * torch.sigmoid(1 - gate_ff)
x = x + residual
else:
x = x + self.self_attn(self.pre_norm(x), mask = mask, rotary_pos_emb = rotary_pos_emb)
if context is not None:
x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask)
if self.conformer is not None:
x = x + self.conformer(x)
x = x + self.ff(self.ff_norm(x))
return x
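# Note on the adaLN branch above: when global conditioning is enabled, one
# SiLU + Linear projection of `global_cond` is chunked into six vectors
# (scale, shift and gate for the self-attention branch and for the feedforward
# branch). Each modulated branch computes
#   norm(x) * (1 + scale) + shift  ->  branch  ->  * sigmoid(1 - gate)
# before the residual is added; cross-attention and the optional conformer
# block are left unmodulated.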
class ContinuousTransformer(nn.Module):
def __init__(
self,
dim,
depth,
*,
dim_in = None,
dim_out = None,
dim_heads = 64,
cross_attend=False,
cond_token_dim=None,
global_cond_dim=None,
causal=False,
rotary_pos_emb=True,
zero_init_branch_outputs=True,
conformer=False,
use_sinusoidal_emb=False,
use_abs_pos_emb=False,
abs_pos_emb_max_length=10000,
**kwargs
):
super().__init__()
self.dim = dim
self.depth = depth
self.causal = causal
self.layers = nn.ModuleList([])
self.project_in = nn.Linear(dim_in, dim, bias=False) if dim_in is not None else nn.Identity()
self.project_out = nn.Linear(dim, dim_out, bias=False) if dim_out is not None else nn.Identity()
if rotary_pos_emb:
self.rotary_pos_emb = RotaryEmbedding(max(dim_heads // 2, 32))
else:
self.rotary_pos_emb = None
self.use_sinusoidal_emb = use_sinusoidal_emb
if use_sinusoidal_emb:
self.pos_emb = ScaledSinusoidalEmbedding(dim)
self.use_abs_pos_emb = use_abs_pos_emb
if use_abs_pos_emb:
self.pos_emb = AbsolutePositionalEmbedding(dim, abs_pos_emb_max_length)
for i in range(depth):
self.layers.append(
TransformerBlock(
dim,
dim_heads = dim_heads,
cross_attend = cross_attend,
dim_context = cond_token_dim,
global_cond_dim = global_cond_dim,
causal = causal,
zero_init_branch_outputs = zero_init_branch_outputs,
conformer=conformer,
layer_ix=i,
**kwargs
)
)
def forward(
self,
x,
mask = None,
prepend_embeds = None,
prepend_mask = None,
global_cond = None,
return_info = False,
**kwargs
):
batch, seq, device = *x.shape[:2], x.device
info = {
"hidden_states": [],
}
x = self.project_in(x)
if prepend_embeds is not None:
prepend_length, prepend_dim = prepend_embeds.shape[1:]
assert prepend_dim == x.shape[-1], 'prepend dimension must match sequence dimension'
x = torch.cat((prepend_embeds, x), dim = -2)
if prepend_mask is not None or mask is not None:
mask = mask if mask is not None else torch.ones((batch, seq), device = device, dtype = torch.bool)
prepend_mask = prepend_mask if prepend_mask is not None else torch.ones((batch, prepend_length), device = device, dtype = torch.bool)
mask = torch.cat((prepend_mask, mask), dim = -1)
# Attention layers
if self.rotary_pos_emb is not None:
rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1])
else:
rotary_pos_emb = None
if self.use_sinusoidal_emb or self.use_abs_pos_emb:
x = x + self.pos_emb(x)
# Iterate over the transformer layers
for layer in self.layers:
#x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)
# pdb.set_trace()
x = checkpoint(layer, x, mask=mask.bool(),rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)
if return_info:
info["hidden_states"].append(x)
x = self.project_out(x)
if return_info:
return x, info
return x
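# Minimal shape-level sketch of ContinuousTransformer usage (illustrative only;
# all dimensions are hypothetical). Prepended embeddings are expected in the
# model dimension, and their tokens are kept in the output sequence.
def _continuous_transformer_sketch():
    model = ContinuousTransformer(dim=256, depth=2, dim_heads=64,
                                  dim_in=80, dim_out=80)
    x = torch.randn(2, 100, 80)                      # (batch, seq, dim_in)
    mask = torch.ones(2, 100, dtype=torch.bool)
    prepend = torch.randn(2, 1, 256)                 # e.g. a global/timestep token
    out = model(x, mask=mask, prepend_embeds=prepend)
    return out.shape                                 # torch.Size([2, 101, 80])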
import pdb
from functools import reduce, partial
from packaging import version
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
import torch
import torch.nn.functional as F
from torch import nn, einsum
from torch.cuda.amp import autocast
from typing import Callable, Literal
try:
from flash_attn import flash_attn_func, flash_attn_kvpacked_func
except ImportError as e:
print(e)
print('flash_attn not installed, disabling Flash Attention')
flash_attn_kvpacked_func = None
flash_attn_func = None
try:
import natten
except ImportError:
natten = None
def checkpoint(function, *args, **kwargs):
kwargs.setdefault("use_reentrant", False)
return torch.utils.checkpoint.checkpoint(function, *args, **kwargs)
# Copied and modified from https://github.com/lucidrains/x-transformers/blob/main/x_transformers/attend.py under MIT License
# License can be found in LICENSES/LICENSE_XTRANSFORMERS.txt
def create_causal_mask(i, j, device):
return torch.ones((i, j), device=device, dtype=torch.bool).triu(j - i + 1)
def or_reduce(masks):
head, *body = masks
for rest in body:
head = head | rest
return head
# positional embeddings
class AbsolutePositionalEmbedding(nn.Module):
def __init__(self, dim, max_seq_len):
super().__init__()
self.scale = dim ** -0.5
self.max_seq_len = max_seq_len
self.emb = nn.Embedding(max_seq_len, dim)
def forward(self, x, pos=None, seq_start_pos=None):
seq_len, device = x.shape[1], x.device
assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}'
if pos is None:
pos = torch.arange(seq_len, device=device)
if seq_start_pos is not None:
pos = (pos - seq_start_pos[..., None]).clamp(min=0)
pos_emb = self.emb(pos)
pos_emb = pos_emb * self.scale
return pos_emb
class ScaledSinusoidalEmbedding(nn.Module):
def __init__(self, dim, theta=10000):
super().__init__()
assert (dim % 2) == 0, 'dimension must be divisible by 2'
self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5)
half_dim = dim // 2
freq_seq = torch.arange(half_dim).float() / half_dim
inv_freq = theta ** -freq_seq
self.register_buffer('inv_freq', inv_freq, persistent=False)
def forward(self, x, pos=None, seq_start_pos=None):
seq_len, device = x.shape[1], x.device
if pos is None:
pos = torch.arange(seq_len, device=device)
if seq_start_pos is not None:
pos = pos - seq_start_pos[..., None]
emb = einsum('i, j -> i j', pos, self.inv_freq)
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb * self.scale
class RotaryEmbedding(nn.Module):
def __init__(
self,
dim,
use_xpos=False,
scale_base=512,
interpolation_factor=1.,
base=10000,
base_rescale_factor=1.
):
super().__init__()
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
# has some connection to NTK literature
# https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
base *= base_rescale_factor ** (dim / (dim - 2))
inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer('inv_freq', inv_freq)
assert interpolation_factor >= 1.
self.interpolation_factor = interpolation_factor
if not use_xpos:
self.register_buffer('scale', None)
return
scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
self.scale_base = scale_base
self.register_buffer('scale', scale)
def forward_from_seq_len(self, seq_len):
device = self.inv_freq.device
t = torch.arange(seq_len, device=device)
return self.forward(t)
@autocast(enabled=False)
def forward(self, t):
device = self.inv_freq.device
t = t.to(torch.float32)
t = t / self.interpolation_factor
freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
freqs = torch.cat((freqs, freqs), dim=-1)
if self.scale is None:
return freqs, 1.
        seq_len = t.shape[-1]
        power = (torch.arange(seq_len, device=device) - (seq_len // 2)) / self.scale_base
scale = self.scale ** rearrange(power, 'n -> n 1')
scale = torch.cat((scale, scale), dim=-1)
return freqs, scale
def rotate_half(x):
x = rearrange(x, '... (j d) -> ... j d', j=2)
x1, x2 = x.unbind(dim=-2)
return torch.cat((-x2, x1), dim=-1)
@autocast(enabled=False)
def apply_rotary_pos_emb(t, freqs, scale=1):
out_dtype = t.dtype
# cast to float32 if necessary for numerical stability
dtype = reduce(torch.promote_types, (t.dtype, freqs.dtype, torch.float32))
rot_dim, seq_len = freqs.shape[-1], t.shape[-2]
freqs, t = freqs.to(dtype), t.to(dtype)
freqs = freqs[-seq_len:, :]
if t.ndim == 4 and freqs.ndim == 3:
freqs = rearrange(freqs, 'b n d -> b 1 n d')
# partial rotary embeddings, Wang et al. GPT-J
t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:]
t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
t, t_unrotated = t.to(out_dtype), t_unrotated.to(out_dtype)
return torch.cat((t, t_unrotated), dim=-1)
# norms
class LayerNorm(nn.Module):
def __init__(self, dim, bias=False, fix_scale=False):
"""
bias-less layernorm has been shown to be more stable. most newer models have moved towards rmsnorm, also bias-less
"""
super().__init__()
if fix_scale:
self.register_buffer("gamma", torch.ones(dim))
else:
self.gamma = nn.Parameter(torch.ones(dim))
if bias:
self.beta = nn.Parameter(torch.zeros(dim))
else:
self.register_buffer("beta", torch.zeros(dim))
def forward(self, x):
return F.layer_norm(x, x.shape[-1:], weight=self.gamma, bias=self.beta)
# feedforward
class GLU(nn.Module):
def __init__(
self,
dim_in,
dim_out,
activation: Callable,
use_conv=False,
conv_kernel_size=3,
):
super().__init__()
self.act = activation
self.proj = nn.Linear(dim_in, dim_out * 2) if not use_conv else nn.Conv1d(dim_in, dim_out * 2, conv_kernel_size,
padding=(conv_kernel_size // 2))
self.use_conv = use_conv
def forward(self, x):
if self.use_conv:
x = rearrange(x, 'b n d -> b d n')
x = self.proj(x)
x = rearrange(x, 'b d n -> b n d')
else:
x = self.proj(x)
x, gate = x.chunk(2, dim=-1)
return x * self.act(gate)
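# Position-wise feedforward. With glu=True (the default) this is a SwiGLU expanding to dim * mult;
# the output projection is zero-initialized so the residual branch starts as an identity mapping.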
class FeedForward(nn.Module):
def __init__(
self,
dim,
dim_out=None,
mult=4,
no_bias=False,
glu=True,
use_conv=False,
conv_kernel_size=3,
zero_init_output=True,
):
super().__init__()
inner_dim = int(dim * mult)
# Default to SwiGLU
activation = nn.SiLU()
dim_out = dim if dim_out is None else dim_out
if glu:
linear_in = GLU(dim, inner_dim, activation)
else:
linear_in = nn.Sequential(
Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
nn.Linear(dim, inner_dim, bias=not no_bias) if not use_conv else nn.Conv1d(dim, inner_dim,
conv_kernel_size, padding=(
conv_kernel_size // 2), bias=not no_bias),
Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
activation
)
linear_out = nn.Linear(inner_dim, dim_out, bias=not no_bias) if not use_conv else nn.Conv1d(inner_dim, dim_out,
conv_kernel_size,
padding=(
conv_kernel_size // 2),
bias=not no_bias)
# init last linear layer to 0
if zero_init_output:
nn.init.zeros_(linear_out.weight)
if not no_bias:
nn.init.zeros_(linear_out.bias)
self.ff = nn.Sequential(
linear_in,
Rearrange('b d n -> b n d') if use_conv else nn.Identity(),
linear_out,
Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
)
def forward(self, x):
return self.ff(x)
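# Multi-head attention. With dim_context set it acts as cross-attention with separate q and kv
# projections (grouped kv heads when dim_context < dim); otherwise a fused qkv projection is used.
# Supports optional l2/ln qk normalization, natten 1d neighborhood attention, and Flash/SDPA backends.
# Boolean masks follow the convention True = position may be attended to.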
class Attention(nn.Module):
def __init__(
self,
dim,
dim_heads=64,
dim_context=None,
causal=False,
zero_init_output=True,
qk_norm: Literal['l2', 'ln', 'none'] = 'none',
natten_kernel_size=None
):
super().__init__()
self.dim = dim
self.dim_heads = dim_heads
self.causal = causal
dim_kv = dim_context if dim_context is not None else dim
self.num_heads = dim // dim_heads
self.kv_heads = dim_kv // dim_heads
if dim_context is not None:
self.to_q = nn.Linear(dim, dim, bias=False)
self.to_kv = nn.Linear(dim_kv, dim_kv * 2, bias=False)
else:
self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
self.to_out = nn.Linear(dim, dim, bias=False)
if zero_init_output:
nn.init.zeros_(self.to_out.weight)
self.qk_norm = qk_norm
if self.qk_norm == "ln":
self.q_norm = nn.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6)
self.k_norm = nn.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6)
# Using 1d neighborhood attention
self.natten_kernel_size = natten_kernel_size
if natten_kernel_size is not None:
return
self.use_pt_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0')
self.use_fa_flash = torch.cuda.is_available() and flash_attn_func is not None
# Flash Attention 2 is explicitly disabled; attention falls back to PyTorch SDPA (use_pt_flash) or the manual path.
self.use_fa_flash = False
self.sdp_kwargs = dict(
enable_flash=True,
enable_math=True,
enable_mem_efficient=True
)
def flash_attn(
self,
q,
k,
v,
mask=None,
causal=None
):
batch, heads, q_len, _, k_len, device = *q.shape, k.shape[-2], q.device
kv_heads = k.shape[1]
# Recommended for multi-query single-key-value attention by Tri Dao
# kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64])
if heads != kv_heads:
# Repeat interleave kv_heads to match q_heads
heads_per_kv_head = heads // kv_heads
k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim=1), (k, v))
if k.ndim == 3:
k = rearrange(k, 'b ... -> b 1 ...').expand_as(q)
if v.ndim == 3:
v = rearrange(v, 'b ... -> b 1 ...').expand_as(q)
causal = self.causal if causal is None else causal
if q_len == 1 and causal:
causal = False
if mask is not None:
assert mask.ndim == 4
mask = mask.expand(batch, heads, q_len, k_len)
assert causal
# handle kv cache - this should be bypassable in updated flash attention 2
if k_len > q_len and causal:
causal_mask = create_causal_mask(q_len, k_len, device=device)
if mask is None:
mask = ~causal_mask
else:
mask = mask & ~causal_mask
causal = False
# manually handle causal mask, if another mask was given
row_is_entirely_masked = None
if mask is not None and causal:
causal_mask = create_causal_mask(q_len, k_len, device=device)
mask = mask & ~causal_mask
# protect against an entire row being masked out
row_is_entirely_masked = ~mask.any(dim=-1)
mask[..., 0] = mask[..., 0] | row_is_entirely_masked
causal = False
with torch.backends.cuda.sdp_kernel(**self.sdp_kwargs):
out = F.scaled_dot_product_attention(
q, k, v,
attn_mask=mask,
is_causal=causal
)
# for a row that is entirely masked out, should zero out the output of that row token
if row_is_entirely_masked is not None:
out = out.masked_fill(row_is_entirely_masked[..., None], 0.)
return out
def forward(
self,
x,
context=None,
mask=None,
context_mask=None,
rotary_pos_emb=None,
causal=None
):
h, kv_h, has_context = self.num_heads, self.kv_heads, context is not None
kv_input = context if has_context else x
if hasattr(self, 'to_q'):
# Use separate linear projections for q and k/v
q = self.to_q(x)
q = rearrange(q, 'b n (h d) -> b h n d', h=h)
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=kv_h), (k, v))
else:
# Use fused linear projection
q, k, v = self.to_qkv(x).chunk(3, dim=-1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))
# Normalize q and k for cosine sim attention
if self.qk_norm == "l2":
q = F.normalize(q, dim=-1)
k = F.normalize(k, dim=-1)
elif self.qk_norm == "ln":
q = self.q_norm(q)
k = self.k_norm(k)
if rotary_pos_emb is not None and not has_context:
freqs, _ = rotary_pos_emb
q_dtype = q.dtype
k_dtype = k.dtype
q = q.to(torch.float32)
k = k.to(torch.float32)
freqs = freqs.to(torch.float32)
q = apply_rotary_pos_emb(q, freqs)
k = apply_rotary_pos_emb(k, freqs)
q = q.to(q_dtype)
k = k.to(k_dtype)
input_mask = context_mask
if input_mask is None and not has_context:
input_mask = mask
# determine masking
masks = []
final_attn_mask = None # The mask that will be applied to the attention matrix, taking all masks into account
if input_mask is not None:
input_mask = rearrange(input_mask, 'b j -> b 1 1 j')
masks.append(~input_mask)
# Other masks will be added here later
if len(masks) > 0:
final_attn_mask = ~or_reduce(masks)
n, device = q.shape[-2], q.device
causal = self.causal if causal is None else causal
if n == 1 and causal:
causal = False
if self.natten_kernel_size is not None:
if natten is None:
raise ImportError('natten not installed, please install natten to use neighborhood attention')
dtype_in = q.dtype
q, k, v = map(lambda t: t.to(torch.float32), (q, k, v))
attn = natten.functional.natten1dqk(q, k, kernel_size=self.natten_kernel_size, dilation=1)
if final_attn_mask is not None:
attn = attn.masked_fill(final_attn_mask, -torch.finfo(attn.dtype).max)
attn = F.softmax(attn, dim=-1, dtype=torch.float32)
out = natten.functional.natten1dav(attn, v, kernel_size=self.natten_kernel_size, dilation=1).to(dtype_in)
# Prioritize Flash Attention 2
elif self.use_fa_flash:
assert final_attn_mask is None, 'masking not yet supported for Flash Attention 2'
# Flash Attention 2 requires FP16 inputs
fa_dtype_in = q.dtype
q, k, v = map(lambda t: rearrange(t, 'b h n d -> b n h d').to(torch.float16), (q, k, v))
out = flash_attn_func(q, k, v, causal=causal)
out = rearrange(out.to(fa_dtype_in), 'b n h d -> b h n d')
# Fall back to PyTorch implementation
elif self.use_pt_flash:
# NOTE: causal=True is passed explicitly here rather than the layer's own causal flag; flash_attn()
# merges the causal mask into final_attn_mask when a mask is provided.
out = self.flash_attn(q, k, v, causal=True, mask=final_attn_mask)
else:
# Fall back to custom implementation
if h != kv_h:
# Repeat interleave kv_heads to match q_heads
heads_per_kv_head = h // kv_h
k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim=1), (k, v))
scale = 1. / (q.shape[-1] ** 0.5)
kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d'
dots = einsum(f'b h i d, {kv_einsum_eq} -> b h i j', q, k) * scale
i, j, dtype = *dots.shape[-2:], dots.dtype
mask_value = -torch.finfo(dots.dtype).max
if final_attn_mask is not None:
dots = dots.masked_fill(~final_attn_mask, mask_value)
if causal:
causal_mask = create_causal_mask(i, j, device=device)
dots = dots.masked_fill(causal_mask, mask_value)
attn = F.softmax(dots, dim=-1, dtype=torch.float32)
attn = attn.type(dtype)
out = einsum(f'b h i j, {kv_einsum_eq} -> b h i d', attn, v)
# merge heads
out = rearrange(out, ' b h n d -> b n (h d)')
# Communicate between heads
# with autocast(enabled = False):
# out_dtype = out.dtype
# out = out.to(torch.float32)
# out = self.to_out(out).to(out_dtype)
out = self.to_out(out)
if mask is not None:
mask = rearrange(mask, 'b n -> b n 1')
out = out.masked_fill(~mask, 0.)
return out
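# Conformer-style convolution module: pointwise conv -> GLU -> depthwise conv (kernel 17) -> norm ->
# SiLU -> pointwise conv, with rearranges around each conv so the block consumes (batch, seq, dim).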
class ConformerModule(nn.Module):
def __init__(
self,
dim,
norm_kwargs={},
):
super().__init__()
self.dim = dim
self.in_norm = LayerNorm(dim, **norm_kwargs)
self.pointwise_conv = nn.Conv1d(dim, dim, kernel_size=1, bias=False)
self.glu = GLU(dim, dim, nn.SiLU())
self.depthwise_conv = nn.Conv1d(dim, dim, kernel_size=17, groups=dim, padding=8, bias=False)
self.mid_norm = LayerNorm(dim,
**norm_kwargs) # This is a batch norm in the original but I don't like batch norm
self.swish = nn.SiLU()
self.pointwise_conv_2 = nn.Conv1d(dim, dim, kernel_size=1, bias=False)
def forward(self, x):
x = self.in_norm(x)
x = rearrange(x, 'b n d -> b d n')
x = self.pointwise_conv(x)
x = rearrange(x, 'b d n -> b n d')
x = self.glu(x)
x = rearrange(x, 'b n d -> b d n')
x = self.depthwise_conv(x)
x = rearrange(x, 'b d n -> b n d')
x = self.mid_norm(x)
x = self.swish(x)
x = rearrange(x, 'b n d -> b d n')
x = self.pointwise_conv_2(x)
x = rearrange(x, 'b d n -> b n d')
return x
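# Pre-norm transformer block: self-attention, optional cross-attention, optional Conformer module and
# a feedforward, each wrapped in a residual connection. When global_cond_dim is set, a SiLU + Linear
# head produces adaLN-style scale/shift/gate terms for the self-attention and feedforward paths.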
class TransformerBlock(nn.Module):
def __init__(
self,
dim,
dim_heads=64,
cross_attend=False,
dim_context=None,
global_cond_dim=None,
causal=False,
zero_init_branch_outputs=True,
conformer=False,
layer_ix=-1,
remove_norms=False,
attn_kwargs={},
ff_kwargs={},
norm_kwargs={}
):
super().__init__()
self.dim = dim
self.dim_heads = dim_heads
self.cross_attend = cross_attend
self.dim_context = dim_context
self.causal = causal
self.pre_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else nn.Identity()
self.self_attn = Attention(
dim,
dim_heads=dim_heads,
causal=causal,
zero_init_output=zero_init_branch_outputs,
**attn_kwargs
)
### 2. This is the main part that needs to be modified (the cross-attention branch below)
if cross_attend:
self.cross_attend_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else nn.Identity()
self.cross_attn = Attention(
dim,
dim_heads=dim_heads,
dim_context=dim_context,
causal=causal,
zero_init_output=zero_init_branch_outputs,
**attn_kwargs
)
self.ff_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else nn.Identity()
self.ff = FeedForward(dim, zero_init_output=zero_init_branch_outputs, **ff_kwargs)
self.layer_ix = layer_ix
self.conformer = ConformerModule(dim, norm_kwargs=norm_kwargs) if conformer else None
self.global_cond_dim = global_cond_dim
if global_cond_dim is not None:
self.to_scale_shift_gate = nn.Sequential(
nn.SiLU(),
nn.Linear(global_cond_dim, dim * 6, bias=False)
)
nn.init.zeros_(self.to_scale_shift_gate[1].weight)
# nn.init.zeros_(self.to_scale_shift_gate_self[1].bias)
def forward(
self,
x,
context=None,
global_cond=None,
mask=None,
context_mask=None,
rotary_pos_emb=None
):
if self.global_cond_dim is not None and self.global_cond_dim > 0 and global_cond is not None:
scale_self, shift_self, gate_self, scale_ff, shift_ff, gate_ff = self.to_scale_shift_gate(
global_cond).unsqueeze(1).chunk(6, dim=-1)
# self-attention with adaLN
residual = x
x = self.pre_norm(x)
x = x * (1 + scale_self) + shift_self
x = self.self_attn(x, mask=mask, rotary_pos_emb=rotary_pos_emb)
x = x * torch.sigmoid(1 - gate_self)
x = x + residual
if context is not None:
x = x + self.cross_attn(self.cross_attend_norm(x), context=context, context_mask=context_mask)
if self.conformer is not None:
x = x + self.conformer(x)
# feedforward with adaLN
residual = x
x = self.ff_norm(x)
x = x * (1 + scale_ff) + shift_ff
x = self.ff(x)
x = x * torch.sigmoid(1 - gate_ff)
x = x + residual
else:
x = x + self.self_attn(self.pre_norm(x), mask=mask, rotary_pos_emb=rotary_pos_emb)
if context is not None:
x = x + self.cross_attn(self.cross_attend_norm(x), context=context, context_mask=context_mask)
if self.conformer is not None:
x = x + self.conformer(x)
x = x + self.ff(self.ff_norm(x))
return x
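# Stack of TransformerBlocks over continuous features: optional input/output projections, rotary or
# sinusoidal/absolute positional embeddings, support for prepended embeddings with their own mask,
# and gradient checkpointing around every layer in forward().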
class ContinuousTransformer(nn.Module):
def __init__(
self,
dim,
depth,
*,
dim_in=None,
dim_out=None,
dim_heads=64,
cross_attend=False,
cond_token_dim=None,
global_cond_dim=None,
causal=False,
rotary_pos_emb=True,
zero_init_branch_outputs=True,
conformer=False,
use_sinusoidal_emb=False,
use_abs_pos_emb=False,
abs_pos_emb_max_length=10000,
**kwargs
):
super().__init__()
self.dim = dim
self.depth = depth
self.causal = causal
self.layers = nn.ModuleList([])
self.project_in = nn.Linear(dim_in, dim, bias=False) if dim_in is not None else nn.Identity()
self.project_out = nn.Linear(dim, dim_out, bias=False) if dim_out is not None else nn.Identity()
if rotary_pos_emb:
self.rotary_pos_emb = RotaryEmbedding(max(dim_heads // 2, 32))
else:
self.rotary_pos_emb = None
self.use_sinusoidal_emb = use_sinusoidal_emb
if use_sinusoidal_emb:
self.pos_emb = ScaledSinusoidalEmbedding(dim)
self.use_abs_pos_emb = use_abs_pos_emb
if use_abs_pos_emb:
self.pos_emb = AbsolutePositionalEmbedding(dim, abs_pos_emb_max_length)
for i in range(depth):
self.layers.append(
TransformerBlock(
dim,
dim_heads=dim_heads,
cross_attend=cross_attend,
dim_context=cond_token_dim,
global_cond_dim=global_cond_dim,
causal=causal,
zero_init_branch_outputs=zero_init_branch_outputs,
conformer=conformer,
layer_ix=i,
**kwargs
)
)
def forward(
self,
x,
mask=None,
prepend_embeds=None,
prepend_mask=None,
global_cond=None,
return_info=False,
**kwargs
):
batch, seq, device = *x.shape[:2], x.device
info = {
"hidden_states": [],
}
x = self.project_in(x)
if prepend_embeds is not None:
prepend_length, prepend_dim = prepend_embeds.shape[1:]
assert prepend_dim == x.shape[-1], 'prepend dimension must match sequence dimension'
x = torch.cat((prepend_embeds, x), dim=-2)
if prepend_mask is not None or mask is not None:
mask = mask if mask is not None else torch.ones((batch, seq), device=device, dtype=torch.bool)
prepend_mask = prepend_mask if prepend_mask is not None else torch.ones((batch, prepend_length),
device=device, dtype=torch.bool)
mask = torch.cat((prepend_mask, mask), dim=-1)
# Attention layers
if self.rotary_pos_emb is not None:
rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1])
else:
rotary_pos_emb = None
if self.use_sinusoidal_emb or self.use_abs_pos_emb:
x = x + self.pos_emb(x)
# Iterate over the transformer layers
mask = self.refine_mask(mask)
for layer in self.layers:
# Gradient checkpointing wraps each layer call to reduce activation memory during training.
x = checkpoint(layer, x, mask=mask.bool() if mask is not None else None, rotary_pos_emb=rotary_pos_emb, global_cond=global_cond, **kwargs)
if return_info:
info["hidden_states"].append(x)
x = self.project_out(x)
if return_info:
return x, info
return x
def refine_mask(self, mask):
return mask  # passthrough hook; override to post-process the attention mask before it reaches the layers
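# Illustrative usage sketch (assumed values, not part of the original file):
#   model = ContinuousTransformer(dim=512, depth=4, dim_in=80, dim_out=80, dim_heads=64)
#   x = torch.randn(2, 100, 80)                     # (batch, seq, dim_in)
#   mask = torch.ones(2, 100, dtype=torch.bool)     # True = valid position
#   y = model(x, mask=mask)                         # -> (batch, seq, dim_out)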