Commit d2a2c140 by zhangqha: BladeDISC aasist code
"""
Main script that trains, validates, and evaluates
various models including AASIST.
AASIST
Copyright (c) 2021-present NAVER Corp.
MIT license
"""
import argparse
import json
import os
import sys
import warnings
from importlib import import_module
from pathlib import Path
from shutil import copy
from typing import Dict, List, Union
import torch
import torch.nn as nn
# BladeDISC support: torch_blade provides torch_blade.optimize and must be
# imported before loading a Blade-optimized TorchScript module.
import torch_blade
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchcontrib.optim import SWA
from data_utils import (Dataset_ASVspoof2019_train,
Dataset_ASVspoof2019_devNeval, genSpoof_list)
from evaluation import calculate_tDCF_EER
from utils import create_optimizer, seed_worker, set_seed, str_to_bool
# BladeDISC configuration: route conversion through torch-mlir and enable
# stitch fusion (flag names as used by BladeDISC).
os.environ["TORCH_DISC_USE_TORCH_MLIR"] = "true"
os.environ["DISC_ENABLE_STITCH"] = "true"
warnings.filterwarnings("ignore", category=FutureWarning)
def main(args: argparse.Namespace) -> None:
"""
Main function.
Trains, validates, and evaluates the ASVspoof detection model.
"""
# load experiment configurations
with open(args.config, "r") as f_json:
config = json.loads(f_json.read())
model_config = config["model_config"]
optim_config = config["optim_config"]
optim_config["epochs"] = config["num_epochs"]
track = config["track"]
assert track in ["LA", "PA", "DF"], "Invalid track given"
if "eval_all_best" not in config:
config["eval_all_best"] = "True"
if "freq_aug" not in config:
config["freq_aug"] = "False"
# make experiment reproducible
set_seed(args.seed, config)
# define database related paths
output_dir = Path(args.output_dir)
prefix_2019 = "ASVspoof2019.{}".format(track)
database_path = Path(config["database_path"])
dev_trial_path = (database_path /
"ASVspoof2019_{}_cm_protocols/{}.cm.dev.trl.txt".format(
track, prefix_2019))
eval_trial_path = (
database_path /
"ASVspoof2019_{}_cm_protocols/{}.cm.eval.trl.txt".format(
track, prefix_2019))
# define model related paths
model_tag = "{}_{}_ep{}_bs{}".format(
track,
os.path.splitext(os.path.basename(args.config))[0],
config["num_epochs"], config["batch_size"])
if args.comment:
model_tag = model_tag + "_{}".format(args.comment)
model_tag = output_dir / model_tag
model_save_path = model_tag / "weights"
eval_score_path = model_tag / config["eval_output"]
writer = SummaryWriter(model_tag)
os.makedirs(model_save_path, exist_ok=True)
copy(args.config, model_tag / "config.conf")
# set device
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device: {}".format(device))
if device == "cpu":
raise ValueError("GPU not detected!")
# define model architecture
model = get_model(model_config, device)
# define dataloaders
trn_loader, dev_loader, eval_loader = get_loader(
database_path, args.seed, config)
# evaluate the pretrained model and exit
if args.eval:
model.load_state_dict(
torch.load(config["model_path"], map_location=device))
print("Model loaded : {}".format(config["model_path"]))
print("Start evaluation...")
produce_evaluation_file(eval_loader, model, device,
eval_score_path, eval_trial_path)
calculate_tDCF_EER(cm_scores_file=eval_score_path,
asv_score_file=database_path /
config["asv_score_path"],
output_file=model_tag / "t-DCF_EER.txt")
print("DONE.")
eval_eer, eval_tdcf = calculate_tDCF_EER(
cm_scores_file=eval_score_path,
asv_score_file=database_path / config["asv_score_path"],
output_file=model_tag/"loaded_model_t-DCF_EER.txt")
sys.exit(0)
# get optimizer and scheduler
optim_config["steps_per_epoch"] = len(trn_loader)
optimizer, scheduler = create_optimizer(model.parameters(), optim_config)
optimizer_swa = SWA(optimizer)
best_dev_eer = 1.
best_eval_eer = 100.
best_dev_tdcf = 0.05
best_eval_tdcf = 1.
n_swa_update = 0 # number of snapshots of model to use in SWA
f_log = open(model_tag / "metric_log.txt", "a")
f_log.write("=" * 5 + "\n")
# make directory for metric logging
metric_path = model_tag / "metrics"
os.makedirs(metric_path, exist_ok=True)
# Training
for epoch in range(config["num_epochs"]):
print("Start training epoch{:03d}".format(epoch))
running_loss = train_epoch(trn_loader, model, optimizer, device,
scheduler, config)
produce_evaluation_file(dev_loader, model, device,
metric_path/"dev_score.txt", dev_trial_path)
dev_eer, dev_tdcf = calculate_tDCF_EER(
cm_scores_file=metric_path/"dev_score.txt",
asv_score_file=database_path/config["asv_score_path"],
output_file=metric_path/"dev_t-DCF_EER_{}epo.txt".format(epoch),
printout=False)
print("DONE.\nLoss:{:.5f}, dev_eer: {:.3f}, dev_tdcf:{:.5f}".format(
running_loss, dev_eer, dev_tdcf))
writer.add_scalar("loss", running_loss, epoch)
writer.add_scalar("dev_eer", dev_eer, epoch)
writer.add_scalar("dev_tdcf", dev_tdcf, epoch)
best_dev_tdcf = min(dev_tdcf, best_dev_tdcf)
if best_dev_eer >= dev_eer:
print("best model find at epoch", epoch)
best_dev_eer = dev_eer
torch.save(model.state_dict(),
model_save_path / "epoch_{}_{:03.3f}.pth".format(epoch, dev_eer))
# do evaluation whenever best model is renewed
if str_to_bool(config["eval_all_best"]):
produce_evaluation_file(eval_loader, model, device,
eval_score_path, eval_trial_path)
eval_eer, eval_tdcf = calculate_tDCF_EER(
cm_scores_file=eval_score_path,
asv_score_file=database_path / config["asv_score_path"],
output_file=metric_path /
"t-DCF_EER_{:03d}epo.txt".format(epoch))
log_text = "epoch{:03d}, ".format(epoch)
if eval_eer < best_eval_eer:
log_text += "best eer, {:.4f}%".format(eval_eer)
best_eval_eer = eval_eer
if eval_tdcf < best_eval_tdcf:
log_text += "best tdcf, {:.4f}".format(eval_tdcf)
best_eval_tdcf = eval_tdcf
torch.save(model.state_dict(),
model_save_path / "best.pth")
if len(log_text) > 0:
print(log_text)
f_log.write(log_text + "\n")
print("Saving epoch {} for swa".format(epoch))
optimizer_swa.update_swa()
n_swa_update += 1
writer.add_scalar("best_dev_eer", best_dev_eer, epoch)
writer.add_scalar("best_dev_tdcf", best_dev_tdcf, epoch)
print("Start final evaluation")
epoch += 1
if n_swa_update > 0:
optimizer_swa.swap_swa_sgd()
optimizer_swa.bn_update(trn_loader, model, device=device)
produce_evaluation_file(eval_loader, model, device, eval_score_path,
eval_trial_path)
eval_eer, eval_tdcf = calculate_tDCF_EER(cm_scores_file=eval_score_path,
asv_score_file=database_path /
config["asv_score_path"],
output_file=model_tag / "t-DCF_EER.txt")
f_log = open(model_tag / "metric_log.txt", "a")
f_log.write("=" * 5 + "\n")
f_log.write("EER: {:.3f}, min t-DCF: {:.5f}".format(eval_eer, eval_tdcf))
f_log.close()
torch.save(model.state_dict(),
model_save_path / "swa.pth")
if eval_eer <= best_eval_eer:
best_eval_eer = eval_eer
if eval_tdcf <= best_eval_tdcf:
best_eval_tdcf = eval_tdcf
torch.save(model.state_dict(),
model_save_path / "best.pth")
print("Exp FIN. EER: {:.3f}, min t-DCF: {:.5f}".format(
best_eval_eer, best_eval_tdcf))
def get_model(model_config: Dict, device: torch.device):
"""Define DNN model architecture"""
module = import_module("models.{}".format(model_config["architecture"]))
_model = getattr(module, "Model")
model = _model(model_config).to(device)
nb_params = sum([param.view(-1).size()[0] for param in model.parameters()])
print("no. model params:{}".format(nb_params))
return model
def get_loader(
database_path: str,
seed: int,
config: dict) -> List[torch.utils.data.DataLoader]:
"""Make PyTorch DataLoaders for train / developement / evaluation"""
track = config["track"]
prefix_2019 = "ASVspoof2019.{}".format(track)
trn_database_path = database_path / "ASVspoof2019_{}_train/".format(track)
dev_database_path = database_path / "ASVspoof2019_{}_dev/".format(track)
eval_database_path = database_path / "ASVspoof2019_{}_eval/".format(track)
trn_list_path = (database_path /
"ASVspoof2019_{}_cm_protocols/{}.cm.train.trn.txt".format(
track, prefix_2019))
dev_trial_path = (database_path /
"ASVspoof2019_{}_cm_protocols/{}.cm.dev.trl.txt".format(
track, prefix_2019))
eval_trial_path = (
database_path /
"ASVspoof2019_{}_cm_protocols/{}.cm.eval.trl.txt".format(
track, prefix_2019))
d_label_trn, file_train = genSpoof_list(dir_meta=trn_list_path,
is_train=True,
is_eval=False)
print("no. training files:", len(file_train))
train_set = Dataset_ASVspoof2019_train(list_IDs=file_train,
labels=d_label_trn,
base_dir=trn_database_path)
gen = torch.Generator()
gen.manual_seed(seed)
trn_loader = DataLoader(train_set,
batch_size=config["batch_size"],
shuffle=True,
drop_last=True,
pin_memory=True,
worker_init_fn=seed_worker,
generator=gen)
_, file_dev = genSpoof_list(dir_meta=dev_trial_path,
is_train=False,
is_eval=False)
print("no. validation files:", len(file_dev))
dev_set = Dataset_ASVspoof2019_devNeval(list_IDs=file_dev,
base_dir=dev_database_path)
dev_loader = DataLoader(dev_set,
batch_size=config["batch_size"],
shuffle=False,
drop_last=False,
pin_memory=True)
file_eval = genSpoof_list(dir_meta=eval_trial_path,
is_train=False,
is_eval=True)
eval_set = Dataset_ASVspoof2019_devNeval(list_IDs=file_eval,
base_dir=eval_database_path)
eval_loader = DataLoader(eval_set,
batch_size=config["batch_size"],
shuffle=False,
drop_last=False,
pin_memory=True)
return trn_loader, dev_loader, eval_loader
def produce_evaluation_file(
data_loader: DataLoader,
model,
device: torch.device,
save_path: str,
trial_path: str) -> None:
"""Perform evaluation and save the score to a file"""
model.eval()
with open(trial_path, "r") as f_trl:
trial_lines = f_trl.readlines()
fname_list = []
score_list = []
for batch_x, utt_id in data_loader:
batch_x = batch_x.to(device)
with torch.no_grad():
# Compile the model with BladeDISC on the first real batch (tracing with
# the actual input), run it once, and save the optimized TorchScript module.
opt_model = torch_blade.optimize(model, allow_tracing=True, model_inputs=batch_x)
_, batch_out = opt_model(batch_x)
torch.jit.save(opt_model, 'aasist_dcu_opt.pt')
batch_score = (batch_out[:, 1]).data.cpu().numpy().ravel()
fname_list.extend(utt_id)
score_list.extend(batch_score.tolist())
# Export-only run: exit after the first batch; scoring is done by the
# companion script that reloads 'aasist_dcu_opt.pt'.
sys.exit(0)
assert len(trial_lines) == len(fname_list) == len(score_list)
with open(save_path, "w") as fh:
for fn, sco, trl in zip(fname_list, score_list, trial_lines):
_, utt_id, _, src, key = trl.strip().split(' ')
assert fn == utt_id
fh.write("{} {} {} {}\n".format(utt_id, src, key, sco))
print("Scores saved to {}".format(save_path))
def train_epoch(
trn_loader: DataLoader,
model,
optim: Union[torch.optim.SGD, torch.optim.Adam],
device: torch.device,
scheduler: torch.optim.lr_scheduler,
config: argparse.Namespace):
"""Train the model for one epoch"""
running_loss = 0
num_total = 0.0
ii = 0
model.train()
# set objective (Loss) functions
weight = torch.FloatTensor([0.1, 0.9]).to(device)
criterion = nn.CrossEntropyLoss(weight=weight)
for batch_x, batch_y in trn_loader:
batch_size = batch_x.size(0)
num_total += batch_size
ii += 1
batch_x = batch_x.to(device)
batch_y = batch_y.view(-1).type(torch.int64).to(device)
_, batch_out = model(batch_x, Freq_aug=str_to_bool(config["freq_aug"]))
batch_loss = criterion(batch_out, batch_y)
running_loss += batch_loss.item() * batch_size
optim.zero_grad()
batch_loss.backward()
optim.step()
if config["optim_config"]["scheduler"] in ["cosine", "keras_decay"]:
scheduler.step()
elif scheduler is None:
pass
else:
raise ValueError("scheduler error, got:{}".format(scheduler))
running_loss /= num_total
return running_loss
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="ASVspoof detection system")
parser.add_argument("--config",
dest="config",
type=str,
help="configuration file",
required=True)
parser.add_argument(
"--output_dir",
dest="output_dir",
type=str,
help="output directory for results",
default="./exp_result",
)
parser.add_argument("--seed",
type=int,
default=1234,
help="random seed (default: 1234)")
parser.add_argument(
"--eval",
action="store_true",
help="when this flag is given, evaluates given model and exit")
parser.add_argument("--comment",
type=str,
default=None,
help="comment to describe the saved model")
parser.add_argument("--eval_model_weights",
type=str,
default=None,
help="directory to the model weight file (can be also given in the config file)")
main(parser.parse_args())
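# ---------------------------------------------------------------------------
# Minimal sketch of the BladeDISC export/reload round-trip performed by the
# script above (export) and the script below (reload + benchmark).
# Illustrative and standalone: nn.Linear stands in for AASIST, the helper
# name is hypothetical, and a CUDA device plus a working BladeDISC install
# are assumed; the file name mirrors the one hard-coded in both scripts.
# ---------------------------------------------------------------------------
def _blade_round_trip_demo():
    import torch
    import torch_blade

    model = torch.nn.Linear(16, 2).cuda().eval()
    example = torch.randn(4, 16).cuda()
    with torch.no_grad():
        # Trace and compile the module with BladeDISC using a real input.
        opt = torch_blade.optimize(model, allow_tracing=True, model_inputs=example)
    torch.jit.save(opt, "aasist_dcu_opt.pt")
    # In another process, import torch_blade first so the custom ops resolve,
    # then reload the optimized TorchScript module for inference.
    reloaded = torch.jit.load("aasist_dcu_opt.pt").cuda().eval()
    with torch.no_grad():
        print(reloaded(example))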
"""
Main script that trains, validates, and evaluates
various models including AASIST.
AASIST
Copyright (c) 2021-present NAVER Corp.
MIT license
"""
import argparse
import json
import os
import sys
import warnings
from importlib import import_module
from pathlib import Path
from shutil import copy
from typing import Dict, List, Union
import torch
import torch.nn as nn
# BladeDISC support
import torch_blade
import timeit
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchcontrib.optim import SWA
from data_utils import (Dataset_ASVspoof2019_train,
Dataset_ASVspoof2019_devNeval, genSpoof_list)
from evaluation import calculate_tDCF_EER
from utils import create_optimizer, seed_worker, set_seed, str_to_bool
os.environ["TORCH_DISC_USE_TORCH_MLIR"] = "true"
os.environ["DISC_ENABLE_STITCH"] = "true"
warnings.filterwarnings("ignore", category=FutureWarning)
def main(args: argparse.Namespace) -> None:
"""
Main function.
Trains, validates, and evaluates the ASVspoof detection model.
"""
# load experiment configurations
with open(args.config, "r") as f_json:
config = json.loads(f_json.read())
model_config = config["model_config"]
optim_config = config["optim_config"]
optim_config["epochs"] = config["num_epochs"]
track = config["track"]
assert track in ["LA", "PA", "DF"], "Invalid track given"
if "eval_all_best" not in config:
config["eval_all_best"] = "True"
if "freq_aug" not in config:
config["freq_aug"] = "False"
# make experiment reproducible
set_seed(args.seed, config)
# define database related paths
output_dir = Path(args.output_dir)
prefix_2019 = "ASVspoof2019.{}".format(track)
database_path = Path(config["database_path"])
dev_trial_path = (database_path /
"ASVspoof2019_{}_cm_protocols/{}.cm.dev.trl.txt".format(
track, prefix_2019))
eval_trial_path = (
database_path /
"ASVspoof2019_{}_cm_protocols/{}.cm.eval.trl.txt".format(
track, prefix_2019))
# define model related paths
model_tag = "{}_{}_ep{}_bs{}".format(
track,
os.path.splitext(os.path.basename(args.config))[0],
config["num_epochs"], config["batch_size"])
if args.comment:
model_tag = model_tag + "_{}".format(args.comment)
model_tag = output_dir / model_tag
model_save_path = model_tag / "weights"
eval_score_path = model_tag / config["eval_output"]
writer = SummaryWriter(model_tag)
os.makedirs(model_save_path, exist_ok=True)
copy(args.config, model_tag / "config.conf")
# set device
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device: {}".format(device))
if device == "cpu":
raise ValueError("GPU not detected!")
# define model architecture
model = get_model(model_config, device)
# define dataloaders
trn_loader, dev_loader, eval_loader = get_loader(
database_path, args.seed, config)
# evaluate the pretrained model and exit
if args.eval:
model.load_state_dict(
torch.load(config["model_path"], map_location=device))
print("Model loaded : {}".format(config["model_path"]))
print("Start evaluation...")
produce_evaluation_file(eval_loader, model, device,
eval_score_path, eval_trial_path)
calculate_tDCF_EER(cm_scores_file=eval_score_path,
asv_score_file=database_path /
config["asv_score_path"],
output_file=model_tag / "t-DCF_EER.txt")
print("DONE.")
eval_eer, eval_tdcf = calculate_tDCF_EER(
cm_scores_file=eval_score_path,
asv_score_file=database_path / config["asv_score_path"],
output_file=model_tag/"loaded_model_t-DCF_EER.txt")
sys.exit(0)
# get optimizer and scheduler
optim_config["steps_per_epoch"] = len(trn_loader)
optimizer, scheduler = create_optimizer(model.parameters(), optim_config)
optimizer_swa = SWA(optimizer)
best_dev_eer = 1.
best_eval_eer = 100.
best_dev_tdcf = 0.05
best_eval_tdcf = 1.
n_swa_update = 0 # number of snapshots of model to use in SWA
f_log = open(model_tag / "metric_log.txt", "a")
f_log.write("=" * 5 + "\n")
# make directory for metric logging
metric_path = model_tag / "metrics"
os.makedirs(metric_path, exist_ok=True)
# Training
for epoch in range(config["num_epochs"]):
print("Start training epoch{:03d}".format(epoch))
running_loss = train_epoch(trn_loader, model, optimizer, device,
scheduler, config)
produce_evaluation_file(dev_loader, model, device,
metric_path/"dev_score.txt", dev_trial_path)
dev_eer, dev_tdcf = calculate_tDCF_EER(
cm_scores_file=metric_path/"dev_score.txt",
asv_score_file=database_path/config["asv_score_path"],
output_file=metric_path/"dev_t-DCF_EER_{}epo.txt".format(epoch),
printout=False)
print("DONE.\nLoss:{:.5f}, dev_eer: {:.3f}, dev_tdcf:{:.5f}".format(
running_loss, dev_eer, dev_tdcf))
writer.add_scalar("loss", running_loss, epoch)
writer.add_scalar("dev_eer", dev_eer, epoch)
writer.add_scalar("dev_tdcf", dev_tdcf, epoch)
best_dev_tdcf = min(dev_tdcf, best_dev_tdcf)
if best_dev_eer >= dev_eer:
print("best model find at epoch", epoch)
best_dev_eer = dev_eer
torch.save(model.state_dict(),
model_save_path / "epoch_{}_{:03.3f}.pth".format(epoch, dev_eer))
# do evaluation whenever best model is renewed
if str_to_bool(config["eval_all_best"]):
produce_evaluation_file(eval_loader, model, device,
eval_score_path, eval_trial_path)
eval_eer, eval_tdcf = calculate_tDCF_EER(
cm_scores_file=eval_score_path,
asv_score_file=database_path / config["asv_score_path"],
output_file=metric_path /
"t-DCF_EER_{:03d}epo.txt".format(epoch))
log_text = "epoch{:03d}, ".format(epoch)
if eval_eer < best_eval_eer:
log_text += "best eer, {:.4f}%".format(eval_eer)
best_eval_eer = eval_eer
if eval_tdcf < best_eval_tdcf:
log_text += "best tdcf, {:.4f}".format(eval_tdcf)
best_eval_tdcf = eval_tdcf
torch.save(model.state_dict(),
model_save_path / "best.pth")
if len(log_text) > 0:
print(log_text)
f_log.write(log_text + "\n")
print("Saving epoch {} for swa".format(epoch))
optimizer_swa.update_swa()
n_swa_update += 1
writer.add_scalar("best_dev_eer", best_dev_eer, epoch)
writer.add_scalar("best_dev_tdcf", best_dev_tdcf, epoch)
print("Start final evaluation")
epoch += 1
if n_swa_update > 0:
optimizer_swa.swap_swa_sgd()
optimizer_swa.bn_update(trn_loader, model, device=device)
produce_evaluation_file(eval_loader, model, device, eval_score_path,
eval_trial_path)
eval_eer, eval_tdcf = calculate_tDCF_EER(cm_scores_file=eval_score_path,
asv_score_file=database_path /
config["asv_score_path"],
output_file=model_tag / "t-DCF_EER.txt")
f_log = open(model_tag / "metric_log.txt", "a")
f_log.write("=" * 5 + "\n")
f_log.write("EER: {:.3f}, min t-DCF: {:.5f}".format(eval_eer, eval_tdcf))
f_log.close()
torch.save(model.state_dict(),
model_save_path / "swa.pth")
if eval_eer <= best_eval_eer:
best_eval_eer = eval_eer
if eval_tdcf <= best_eval_tdcf:
best_eval_tdcf = eval_tdcf
torch.save(model.state_dict(),
model_save_path / "best.pth")
print("Exp FIN. EER: {:.3f}, min t-DCF: {:.5f}".format(
best_eval_eer, best_eval_tdcf))
def get_model(model_config: Dict, device: torch.device):
"""Define DNN model architecture"""
module = import_module("models.{}".format(model_config["architecture"]))
_model = getattr(module, "Model")
model = _model(model_config).to(device)
nb_params = sum([param.view(-1).size()[0] for param in model.parameters()])
print("no. model params:{}".format(nb_params))
return model
def get_loader(
database_path: str,
seed: int,
config: dict) -> List[torch.utils.data.DataLoader]:
"""Make PyTorch DataLoaders for train / developement / evaluation"""
track = config["track"]
prefix_2019 = "ASVspoof2019.{}".format(track)
trn_database_path = database_path / "ASVspoof2019_{}_train/".format(track)
dev_database_path = database_path / "ASVspoof2019_{}_dev/".format(track)
eval_database_path = database_path / "ASVspoof2019_{}_eval/".format(track)
trn_list_path = (database_path /
"ASVspoof2019_{}_cm_protocols/{}.cm.train.trn.txt".format(
track, prefix_2019))
dev_trial_path = (database_path /
"ASVspoof2019_{}_cm_protocols/{}.cm.dev.trl.txt".format(
track, prefix_2019))
eval_trial_path = (
database_path /
"ASVspoof2019_{}_cm_protocols/{}.cm.eval.trl.txt".format(
track, prefix_2019))
d_label_trn, file_train = genSpoof_list(dir_meta=trn_list_path,
is_train=True,
is_eval=False)
print("no. training files:", len(file_train))
train_set = Dataset_ASVspoof2019_train(list_IDs=file_train,
labels=d_label_trn,
base_dir=trn_database_path)
gen = torch.Generator()
gen.manual_seed(seed)
trn_loader = DataLoader(train_set,
batch_size=config["batch_size"],
shuffle=True,
drop_last=True,
pin_memory=True,
worker_init_fn=seed_worker,
generator=gen)
_, file_dev = genSpoof_list(dir_meta=dev_trial_path,
is_train=False,
is_eval=False)
print("no. validation files:", len(file_dev))
dev_set = Dataset_ASVspoof2019_devNeval(list_IDs=file_dev,
base_dir=dev_database_path)
dev_loader = DataLoader(dev_set,
batch_size=config["batch_size"],
shuffle=False,
drop_last=False,
pin_memory=True)
file_eval = genSpoof_list(dir_meta=eval_trial_path,
is_train=False,
is_eval=True)
eval_set = Dataset_ASVspoof2019_devNeval(list_IDs=file_eval,
base_dir=eval_database_path)
eval_loader = DataLoader(eval_set,
batch_size=config["batch_size"],
shuffle=False,
drop_last=False,
pin_memory=True)
return trn_loader, dev_loader, eval_loader
def produce_evaluation_file(
data_loader: DataLoader,
model,
device: torch.device,
save_path: str,
trial_path: str) -> None:
"""Perform evaluation and save the score to a file"""
# Reload the BladeDISC-optimized TorchScript module exported by main.py
# (replaces the freshly constructed `model` argument).
model = torch.jit.load('./aasist_dcu_opt.pt').cuda().eval()
with open(trial_path, "r") as f_trl:
trial_lines = f_trl.readlines()
fname_list = []
score_list = []
for batch_x, utt_id in data_loader:
batch_x = batch_x.to(device)
with torch.no_grad():
start_time = timeit.default_timer()
_, batch_out = model(batch_x)
# Synchronize before stopping the timer so the GPU work is actually measured.
torch.cuda.synchronize()
print("optimized model: {:.4f} seconds for this batch".format(
timeit.default_timer() - start_time))
batch_score = (batch_out[:, 1]).data.cpu().numpy().ravel()
fname_list.extend(utt_id)
score_list.extend(batch_score.tolist())
assert len(trial_lines) == len(fname_list) == len(score_list)
with open(save_path, "w") as fh:
for fn, sco, trl in zip(fname_list, score_list, trial_lines):
_, utt_id, _, src, key = trl.strip().split(' ')
assert fn == utt_id
fh.write("{} {} {} {}\n".format(utt_id, src, key, sco))
print("Scores saved to {}".format(save_path))
def train_epoch(
trn_loader: DataLoader,
model,
optim: Union[torch.optim.SGD, torch.optim.Adam],
device: torch.device,
scheduler: torch.optim.lr_scheduler,
config: argparse.Namespace):
"""Train the model for one epoch"""
running_loss = 0
num_total = 0.0
ii = 0
model.train()
# set objective (Loss) functions
weight = torch.FloatTensor([0.1, 0.9]).to(device)
criterion = nn.CrossEntropyLoss(weight=weight)
for batch_x, batch_y in trn_loader:
batch_size = batch_x.size(0)
num_total += batch_size
ii += 1
batch_x = batch_x.to(device)
batch_y = batch_y.view(-1).type(torch.int64).to(device)
_, batch_out = model(batch_x, Freq_aug=str_to_bool(config["freq_aug"]))
batch_loss = criterion(batch_out, batch_y)
running_loss += batch_loss.item() * batch_size
optim.zero_grad()
batch_loss.backward()
optim.step()
if config["optim_config"]["scheduler"] in ["cosine", "keras_decay"]:
scheduler.step()
elif scheduler is None:
pass
else:
raise ValueError("scheduler error, got:{}".format(scheduler))
running_loss /= num_total
return running_loss
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="ASVspoof detection system")
parser.add_argument("--config",
dest="config",
type=str,
help="configuration file",
required=True)
parser.add_argument(
"--output_dir",
dest="output_dir",
type=str,
help="output directory for results",
default="./exp_result",
)
parser.add_argument("--seed",
type=int,
default=1234,
help="random seed (default: 1234)")
parser.add_argument(
"--eval",
action="store_true",
help="when this flag is given, evaluates given model and exit")
parser.add_argument("--comment",
type=str,
default=None,
help="comment to describe the saved model")
parser.add_argument("--eval_model_weights",
type=str,
default=None,
help="directory to the model weight file (can be also given in the config file)")
main(parser.parse_args())
"""
AASIST
Copyright (c) 2021-present NAVER Corp.
MIT license
"""
import random
from typing import Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
class GraphAttentionLayer(nn.Module):
def __init__(self, in_dim, out_dim, **kwargs):
super().__init__()
# attention map
self.att_proj = nn.Linear(in_dim, out_dim)
self.att_weight = self._init_new_params(out_dim, 1)
# project
self.proj_with_att = nn.Linear(in_dim, out_dim)
self.proj_without_att = nn.Linear(in_dim, out_dim)
# batch norm
self.bn = nn.BatchNorm1d(out_dim)
# dropout for inputs
self.input_drop = nn.Dropout(p=0.2)
# activate
self.act = nn.SELU(inplace=True)
# temperature
self.temp = 1.
if "temperature" in kwargs:
self.temp = kwargs["temperature"]
def forward(self, x):
'''
x :(#bs, #node, #dim)
'''
# apply input dropout
x = self.input_drop(x)
# derive attention map
att_map = self._derive_att_map(x)
# projection
x = self._project(x, att_map)
# apply batch norm
x = self._apply_BN(x)
x = self.act(x)
return x
def _pairwise_mul_nodes(self, x):
'''
Calculates pairwise multiplication of nodes.
- for attention map
x :(#bs, #node, #dim)
out_shape :(#bs, #node, #node, #dim)
'''
nb_nodes = x.size(1)
x = x.unsqueeze(2).expand(-1, -1, nb_nodes, -1)
x_mirror = x.transpose(1, 2)
return x * x_mirror
def _derive_att_map(self, x):
'''
x :(#bs, #node, #dim)
out_shape :(#bs, #node, #node, 1)
'''
att_map = self._pairwise_mul_nodes(x)
# size: (#bs, #node, #node, #dim_out)
att_map = torch.tanh(self.att_proj(att_map))
# size: (#bs, #node, #node, 1)
att_map = torch.matmul(att_map, self.att_weight)
# apply temperature
att_map = att_map / self.temp
att_map = F.softmax(att_map, dim=-2)
return att_map
def _project(self, x, att_map):
x1 = self.proj_with_att(torch.matmul(att_map.squeeze(-1), x))
x2 = self.proj_without_att(x)
return x1 + x2
def _apply_BN(self, x):
org_size = x.size()
x = x.view(-1, org_size[-1])
x = self.bn(x)
x = x.view(org_size)
return x
def _init_new_params(self, *size):
out = nn.Parameter(torch.FloatTensor(*size))
nn.init.xavier_normal_(out)
return out
class HtrgGraphAttentionLayer(nn.Module):
def __init__(self, in_dim, out_dim, **kwargs):
super().__init__()
self.proj_type1 = nn.Linear(in_dim, in_dim)
self.proj_type2 = nn.Linear(in_dim, in_dim)
# attention map
self.att_proj = nn.Linear(in_dim, out_dim)
self.att_projM = nn.Linear(in_dim, out_dim)
self.att_weight11 = self._init_new_params(out_dim, 1)
self.att_weight22 = self._init_new_params(out_dim, 1)
self.att_weight12 = self._init_new_params(out_dim, 1)
self.att_weightM = self._init_new_params(out_dim, 1)
# project
self.proj_with_att = nn.Linear(in_dim, out_dim)
self.proj_without_att = nn.Linear(in_dim, out_dim)
self.proj_with_attM = nn.Linear(in_dim, out_dim)
self.proj_without_attM = nn.Linear(in_dim, out_dim)
# batch norm
self.bn = nn.BatchNorm1d(out_dim)
# dropout for inputs
self.input_drop = nn.Dropout(p=0.2)
# activate
self.act = nn.SELU(inplace=True)
# temperature
self.temp = 1.
if "temperature" in kwargs:
self.temp = kwargs["temperature"]
def forward(self, x1, x2, master=None):
'''
x1 :(#bs, #node, #dim)
x2 :(#bs, #node, #dim)
'''
num_type1 = x1.size(1)
num_type2 = x2.size(1)
x1 = self.proj_type1(x1)
x2 = self.proj_type2(x2)
x = torch.cat([x1, x2], dim=1)
if master is None:
master = torch.mean(x, dim=1, keepdim=True)
# apply input dropout
x = self.input_drop(x)
# derive attention map
att_map = self._derive_att_map(x, num_type1, num_type2)
# directional edge for master node
master = self._update_master(x, master)
# projection
x = self._project(x, att_map)
# apply batch norm
x = self._apply_BN(x)
x = self.act(x)
x1 = x.narrow(1, 0, num_type1)
x2 = x.narrow(1, num_type1, num_type2)
return x1, x2, master
def _update_master(self, x, master):
att_map = self._derive_att_map_master(x, master)
master = self._project_master(x, master, att_map)
return master
def _pairwise_mul_nodes(self, x):
'''
Calculates pairwise multiplication of nodes.
- for attention map
x :(#bs, #node, #dim)
out_shape :(#bs, #node, #node, #dim)
'''
nb_nodes = x.size(1)
x = x.unsqueeze(2).expand(-1, -1, nb_nodes, -1)
x_mirror = x.transpose(1, 2)
return x * x_mirror
def _derive_att_map_master(self, x, master):
'''
x :(#bs, #node, #dim)
out_shape :(#bs, #node, #node, 1)
'''
att_map = x * master
att_map = torch.tanh(self.att_projM(att_map))
att_map = torch.matmul(att_map, self.att_weightM)
# apply temperature
att_map = att_map / self.temp
att_map = F.softmax(att_map, dim=-2)
return att_map
def _derive_att_map(self, x, num_type1, num_type2):
'''
x :(#bs, #node, #dim)
out_shape :(#bs, #node, #node, 1)
'''
att_map = self._pairwise_mul_nodes(x)
# size: (#bs, #node, #node, #dim_out)
att_map = torch.tanh(self.att_proj(att_map))
# size: (#bs, #node, #node, 1)
att_board = torch.zeros_like(att_map[:, :, :, 0]).unsqueeze(-1)
att_board[:, :num_type1, :num_type1, :] = torch.matmul(
att_map[:, :num_type1, :num_type1, :], self.att_weight11)
att_board[:, num_type1:, num_type1:, :] = torch.matmul(
att_map[:, num_type1:, num_type1:, :], self.att_weight22)
att_board[:, :num_type1, num_type1:, :] = torch.matmul(
att_map[:, :num_type1, num_type1:, :], self.att_weight12)
att_board[:, num_type1:, :num_type1, :] = torch.matmul(
att_map[:, num_type1:, :num_type1, :], self.att_weight12)
att_map = att_board
# att_map = torch.matmul(att_map, self.att_weight12)
# apply temperature
att_map = att_map / self.temp
att_map = F.softmax(att_map, dim=-2)
return att_map
def _project(self, x, att_map):
x1 = self.proj_with_att(torch.matmul(att_map.squeeze(-1), x))
x2 = self.proj_without_att(x)
return x1 + x2
def _project_master(self, x, master, att_map):
x1 = self.proj_with_attM(torch.matmul(
att_map.squeeze(-1).unsqueeze(1), x))
x2 = self.proj_without_attM(master)
return x1 + x2
def _apply_BN(self, x):
org_size = x.size()
x = x.view(-1, org_size[-1])
x = self.bn(x)
x = x.view(org_size)
return x
def _init_new_params(self, *size):
out = nn.Parameter(torch.FloatTensor(*size))
nn.init.xavier_normal_(out)
return out
class GraphPool(nn.Module):
def __init__(self, k: float, in_dim: int, p: Union[float, int]):
super().__init__()
self.k = k
self.sigmoid = nn.Sigmoid()
self.proj = nn.Linear(in_dim, 1)
self.drop = nn.Dropout(p=p) if p > 0 else nn.Identity()
self.in_dim = in_dim
def forward(self, h):
Z = self.drop(h)
weights = self.proj(Z)
scores = self.sigmoid(weights)
new_h = self.top_k_graph(scores, h, self.k)
return new_h
def top_k_graph(self, scores, h, k):
"""
args
=====
scores: attention-based weights (#bs, #node, 1)
h: graph data (#bs, #node, #dim)
k: ratio of remaining nodes, (float)
returns
=====
h: graph pool applied data (#bs, #node', #dim)
"""
_, n_nodes, n_feat = h.size()
n_nodes = max(int(n_nodes * k), 1)
_, idx = torch.topk(scores, n_nodes, dim=1)
idx = idx.expand(-1, -1, n_feat)
h = h * scores
h = torch.gather(h, 1, idx)
return h
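# Note on GraphPool.top_k_graph: each node's features are first scaled by its
# sigmoid score (soft gating), then only the int(k * #node) highest-scoring
# nodes (at least 1) are gathered, shrinking the graph along the node axis.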
class CONV(nn.Module):
@staticmethod
def to_mel(hz):
return 2595 * np.log10(1 + hz / 700)
@staticmethod
def to_hz(mel):
return 700 * (10**(mel / 2595) - 1)
def __init__(self,
out_channels,
kernel_size,
sample_rate=16000,
in_channels=1,
stride=1,
padding=0,
dilation=1,
bias=False,
groups=1,
mask=False):
super().__init__()
if in_channels != 1:
msg = "SincConv only support one input channel (here, in_channels = {%i})" % (
in_channels)
raise ValueError(msg)
self.out_channels = out_channels
self.kernel_size = kernel_size
self.sample_rate = sample_rate
# Forcing the filters to be odd (i.e., perfectly symmetric)
if kernel_size % 2 == 0:
self.kernel_size = self.kernel_size + 1
self.stride = stride
self.padding = padding
self.dilation = dilation
self.mask = mask
if bias:
raise ValueError('SincConv does not support bias.')
if groups > 1:
raise ValueError('SincConv does not support groups.')
NFFT = 512
f = int(self.sample_rate / 2) * np.linspace(0, 1, int(NFFT / 2) + 1)
fmel = self.to_mel(f)
fmelmax = np.max(fmel)
fmelmin = np.min(fmel)
filbandwidthsmel = np.linspace(fmelmin, fmelmax, self.out_channels + 1)
filbandwidthsf = self.to_hz(filbandwidthsmel)
self.mel = filbandwidthsf
self.hsupp = torch.arange(-(self.kernel_size - 1) / 2,
(self.kernel_size - 1) / 2 + 1)
self.band_pass = torch.zeros(self.out_channels, self.kernel_size)
for i in range(len(self.mel) - 1):
fmin = self.mel[i]
fmax = self.mel[i + 1]
hHigh = (2*fmax/self.sample_rate) * \
np.sinc(2*fmax*self.hsupp/self.sample_rate)
hLow = (2*fmin/self.sample_rate) * \
np.sinc(2*fmin*self.hsupp/self.sample_rate)
hideal = hHigh - hLow
self.band_pass[i, :] = Tensor(np.hamming(
self.kernel_size)) * Tensor(hideal)
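# Each row i of self.band_pass is a Hamming-windowed ideal band-pass kernel
# for the i-th mel-spaced band [fmin, fmax]: the difference of two sinc
# low-pass filters with cutoffs fmax and fmin.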
def forward(self, x, mask=False):
band_pass_filter = self.band_pass.clone().to(x.device)
if mask:
A = np.random.uniform(0, 20)
A = int(A)
A0 = random.randint(0, band_pass_filter.shape[0] - A)
band_pass_filter[A0:A0 + A, :] = 0
else:
band_pass_filter = band_pass_filter
self.filters = (band_pass_filter).view(self.out_channels, 1,
self.kernel_size)
return F.conv1d(x,
self.filters,
stride=self.stride,
padding=self.padding,
dilation=self.dilation,
bias=None,
groups=1)
class Residual_block(nn.Module):
def __init__(self, nb_filts, first=False):
super().__init__()
self.first = first
if not self.first:
self.bn1 = nn.BatchNorm2d(num_features=nb_filts[0])
self.conv1 = nn.Conv2d(in_channels=nb_filts[0],
out_channels=nb_filts[1],
kernel_size=(2, 3),
padding=(1, 1),
stride=1)
self.selu = nn.SELU(inplace=True)
self.bn2 = nn.BatchNorm2d(num_features=nb_filts[1])
self.conv2 = nn.Conv2d(in_channels=nb_filts[1],
out_channels=nb_filts[1],
kernel_size=(2, 3),
padding=(0, 1),
stride=1)
if nb_filts[0] != nb_filts[1]:
self.downsample = True
self.conv_downsample = nn.Conv2d(in_channels=nb_filts[0],
out_channels=nb_filts[1],
padding=(0, 1),
kernel_size=(1, 3),
stride=1)
else:
self.downsample = False
self.mp = nn.MaxPool2d((1, 3)) # self.mp = nn.MaxPool2d((1,4))
def forward(self, x):
identity = x
if not self.first:
out = self.bn1(x)
out = self.selu(out)
else:
out = x
out = self.conv1(x)
# print('out',out.shape)
out = self.bn2(out)
out = self.selu(out)
# print('out',out.shape)
out = self.conv2(out)
#print('conv2 out',out.shape)
if self.downsample:
identity = self.conv_downsample(identity)
out += identity
out = self.mp(out)
return out
class Model(nn.Module):
def __init__(self, d_args):
super().__init__()
self.d_args = d_args
filts = d_args["filts"]
gat_dims = d_args["gat_dims"]
pool_ratios = d_args["pool_ratios"]
temperatures = d_args["temperatures"]
self.conv_time = CONV(out_channels=filts[0],
kernel_size=d_args["first_conv"],
in_channels=1)
self.first_bn = nn.BatchNorm2d(num_features=1)
self.drop = nn.Dropout(0.5, inplace=True)
self.drop_way = nn.Dropout(0.2, inplace=True)
self.selu = nn.SELU(inplace=True)
self.encoder = nn.Sequential(
nn.Sequential(Residual_block(nb_filts=filts[1], first=True)),
nn.Sequential(Residual_block(nb_filts=filts[2])),
nn.Sequential(Residual_block(nb_filts=filts[3])),
nn.Sequential(Residual_block(nb_filts=filts[4])),
nn.Sequential(Residual_block(nb_filts=filts[4])),
nn.Sequential(Residual_block(nb_filts=filts[4])))
self.pos_S = nn.Parameter(torch.randn(1, 23, filts[-1][-1]))
self.master1 = nn.Parameter(torch.randn(1, 1, gat_dims[0]))
self.master2 = nn.Parameter(torch.randn(1, 1, gat_dims[0]))
self.GAT_layer_S = GraphAttentionLayer(filts[-1][-1],
gat_dims[0],
temperature=temperatures[0])
self.GAT_layer_T = GraphAttentionLayer(filts[-1][-1],
gat_dims[0],
temperature=temperatures[1])
self.HtrgGAT_layer_ST11 = HtrgGraphAttentionLayer(
gat_dims[0], gat_dims[1], temperature=temperatures[2])
self.HtrgGAT_layer_ST12 = HtrgGraphAttentionLayer(
gat_dims[1], gat_dims[1], temperature=temperatures[2])
self.HtrgGAT_layer_ST21 = HtrgGraphAttentionLayer(
gat_dims[0], gat_dims[1], temperature=temperatures[2])
self.HtrgGAT_layer_ST22 = HtrgGraphAttentionLayer(
gat_dims[1], gat_dims[1], temperature=temperatures[2])
self.pool_S = GraphPool(pool_ratios[0], gat_dims[0], 0.3)
self.pool_T = GraphPool(pool_ratios[1], gat_dims[0], 0.3)
self.pool_hS1 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
self.pool_hT1 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
self.pool_hS2 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
self.pool_hT2 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
self.out_layer = nn.Linear(5 * gat_dims[1], 2)
def forward(self, x, Freq_aug=False):
x = x.unsqueeze(1)
x = self.conv_time(x, mask=Freq_aug)
x = x.unsqueeze(dim=1)
x = F.max_pool2d(torch.abs(x), (3, 3))
x = self.first_bn(x)
x = self.selu(x)
# get embeddings using encoder
# (#bs, #filt, #spec, #seq)
e = self.encoder(x)
# spectral GAT (GAT-S)
e_S, _ = torch.max(torch.abs(e), dim=3) # max along time
e_S = e_S.transpose(1, 2) + self.pos_S
gat_S = self.GAT_layer_S(e_S)
out_S = self.pool_S(gat_S) # (#bs, #node, #dim)
# temporal GAT (GAT-T)
e_T, _ = torch.max(torch.abs(e), dim=2) # max along freq
e_T = e_T.transpose(1, 2)
gat_T = self.GAT_layer_T(e_T)
out_T = self.pool_T(gat_T)
# learnable master node
master1 = self.master1.expand(x.size(0), -1, -1)
master2 = self.master2.expand(x.size(0), -1, -1)
# inference 1
out_T1, out_S1, master1 = self.HtrgGAT_layer_ST11(
out_T, out_S, master=self.master1)
out_S1 = self.pool_hS1(out_S1)
out_T1 = self.pool_hT1(out_T1)
out_T_aug, out_S_aug, master_aug = self.HtrgGAT_layer_ST12(
out_T1, out_S1, master=master1)
out_T1 = out_T1 + out_T_aug
out_S1 = out_S1 + out_S_aug
master1 = master1 + master_aug
# inference 2
out_T2, out_S2, master2 = self.HtrgGAT_layer_ST21(
out_T, out_S, master=self.master2)
out_S2 = self.pool_hS2(out_S2)
out_T2 = self.pool_hT2(out_T2)
out_T_aug, out_S_aug, master_aug = self.HtrgGAT_layer_ST22(
out_T2, out_S2, master=master2)
out_T2 = out_T2 + out_T_aug
out_S2 = out_S2 + out_S_aug
master2 = master2 + master_aug
out_T1 = self.drop_way(out_T1)
out_T2 = self.drop_way(out_T2)
out_S1 = self.drop_way(out_S1)
out_S2 = self.drop_way(out_S2)
master1 = self.drop_way(master1)
master2 = self.drop_way(master2)
out_T = torch.max(out_T1, out_T2)
out_S = torch.max(out_S1, out_S2)
master = torch.max(master1, master2)
T_max, _ = torch.max(torch.abs(out_T), dim=1)
T_avg = torch.mean(out_T, dim=1)
S_max, _ = torch.max(torch.abs(out_S), dim=1)
S_avg = torch.mean(out_S, dim=1)
last_hidden = torch.cat(
[T_max, T_avg, S_max, S_avg, master.squeeze(1)], dim=1)
last_hidden = self.drop(last_hidden)
output = self.out_layer(last_hidden)
return last_hidden, output
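# ---------------------------------------------------------------------------
# Hypothetical smoke test for the Model above. The d_args values follow the
# published AASIST configuration; the real values come from the experiment's
# .conf file, so treat these as illustrative defaults.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    d_args = {
        "first_conv": 128,
        "filts": [70, [1, 32], [32, 32], [32, 64], [64, 64]],
        "gat_dims": [64, 32],
        "pool_ratios": [0.5, 0.7, 0.5, 0.5],
        "temperatures": [2.0, 2.0, 100.0, 100.0],
    }
    net = Model(d_args).eval()
    wav = torch.randn(2, 64600)  # (batch, samples): ~4 s of 16 kHz audio
    with torch.no_grad():
        hidden, logits = net(wav)
    print(hidden.shape, logits.shape)  # expected: (2, 160) and (2, 2)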
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
import numpy as np
from torch.utils import data
from torch.nn.parameter import Parameter
__author__ = "Hemlata Tak"
__email__ = "tak@eurecom.fr"
class SincConv(nn.Module):
@staticmethod
def to_mel(hz):
return 2595 * np.log10(1 + hz / 700)
@staticmethod
def to_hz(mel):
return 700 * (10**(mel / 2595) - 1)
def __init__(
self,
#device,
out_channels,
kernel_size,
in_channels=1,
sample_rate=16000,
stride=1,
padding=0,
dilation=1,
bias=False,
groups=1,
):
super().__init__()
if in_channels != 1:
msg = (
"SincConv only supports one input channel (here, in_channels = %i)"
% (in_channels))
raise ValueError(msg)
self.out_channels = out_channels
self.kernel_size = kernel_size
self.sample_rate = sample_rate
# Forcing the filters to be odd (i.e., perfectly symmetric)
if kernel_size % 2 == 0:
self.kernel_size = self.kernel_size + 1
#self.device = device
self.stride = stride
self.padding = padding
self.dilation = dilation
if bias:
raise ValueError("SincConv does not support bias.")
if groups > 1:
raise ValueError("SincConv does not support groups.")
# initialize filterbanks using Mel scale
NFFT = 512
f = int(self.sample_rate / 2) * np.linspace(0, 1, int(NFFT / 2) + 1)
fmel = self.to_mel(f) # Hz to mel conversion
fmelmax = np.max(fmel)
fmelmin = np.min(fmel)
filbandwidthsmel = np.linspace(fmelmin, fmelmax, self.out_channels + 1)
filbandwidthsf = self.to_hz(filbandwidthsmel) # Mel to Hz conversion
self.mel = filbandwidthsf
self.hsupp = torch.arange(-(self.kernel_size - 1) / 2,
(self.kernel_size - 1) / 2 + 1)
self.band_pass = torch.zeros(self.out_channels, self.kernel_size)
def forward(self, x):
for i in range(len(self.mel) - 1):
fmin = self.mel[i]
fmax = self.mel[i + 1]
hHigh = (2 * fmax / self.sample_rate) * np.sinc(
2 * fmax * self.hsupp / self.sample_rate)
hLow = (2 * fmin / self.sample_rate) * np.sinc(
2 * fmin * self.hsupp / self.sample_rate)
hideal = hHigh - hLow
self.band_pass[i, :] = Tensor(np.hamming(
self.kernel_size)) * Tensor(hideal)
band_pass_filter = self.band_pass.to(x.device)
self.filters = (band_pass_filter).view(self.out_channels, 1,
self.kernel_size)
return F.conv1d(
x,
self.filters,
stride=self.stride,
padding=self.padding,
dilation=self.dilation,
bias=None,
groups=1,
)
class Residual_block(nn.Module):
def __init__(self, nb_filts, first=False):
super(Residual_block, self).__init__()
self.first = first
if not self.first:
self.bn1 = nn.BatchNorm1d(num_features=nb_filts[0])
self.lrelu = nn.LeakyReLU(negative_slope=0.3)
self.conv1 = nn.Conv1d(
in_channels=nb_filts[0],
out_channels=nb_filts[1],
kernel_size=3,
padding=1,
stride=1,
)
self.bn2 = nn.BatchNorm1d(num_features=nb_filts[1])
self.conv2 = nn.Conv1d(
in_channels=nb_filts[1],
out_channels=nb_filts[1],
padding=1,
kernel_size=3,
stride=1,
)
if nb_filts[0] != nb_filts[1]:
self.downsample = True
self.conv_downsample = nn.Conv1d(
in_channels=nb_filts[0],
out_channels=nb_filts[1],
padding=0,
kernel_size=1,
stride=1,
)
else:
self.downsample = False
self.mp = nn.MaxPool1d(3)
def forward(self, x):
identity = x
if not self.first:
out = self.bn1(x)
out = self.lrelu(out)
else:
out = x
out = self.conv1(x)
out = self.bn2(out)
out = self.lrelu(out)
out = self.conv2(out)
if self.downsample:
identity = self.conv_downsample(identity)
out += identity
out = self.mp(out)
return out
class Model(nn.Module):
#def __init__(self, d_args, device):
def __init__(self, d_args):
super().__init__()
#self.device = device
self.Sinc_conv = SincConv(
#device=self.device,
out_channels=d_args["filts"][0],
kernel_size=d_args["first_conv"],
in_channels=d_args["in_channels"],
)
self.first_bn = nn.BatchNorm1d(num_features=d_args["filts"][0])
self.selu = nn.SELU(inplace=True)
self.block0 = nn.Sequential(
Residual_block(nb_filts=d_args["filts"][1], first=True))
self.block1 = nn.Sequential(
Residual_block(nb_filts=d_args["filts"][1]))
self.block2 = nn.Sequential(
Residual_block(nb_filts=d_args["filts"][2]))
d_args["filts"][2][0] = d_args["filts"][2][1]
self.block3 = nn.Sequential(
Residual_block(nb_filts=d_args["filts"][2]))
self.block4 = nn.Sequential(
Residual_block(nb_filts=d_args["filts"][2]))
self.block5 = nn.Sequential(
Residual_block(nb_filts=d_args["filts"][2]))
self.avgpool = nn.AdaptiveAvgPool1d(1)
self.fc_attention0 = self._make_attention_fc(
in_features=d_args["filts"][1][-1],
l_out_features=d_args["filts"][1][-1])
self.fc_attention1 = self._make_attention_fc(
in_features=d_args["filts"][1][-1],
l_out_features=d_args["filts"][1][-1])
self.fc_attention2 = self._make_attention_fc(
in_features=d_args["filts"][2][-1],
l_out_features=d_args["filts"][2][-1])
self.fc_attention3 = self._make_attention_fc(
in_features=d_args["filts"][2][-1],
l_out_features=d_args["filts"][2][-1])
self.fc_attention4 = self._make_attention_fc(
in_features=d_args["filts"][2][-1],
l_out_features=d_args["filts"][2][-1])
self.fc_attention5 = self._make_attention_fc(
in_features=d_args["filts"][2][-1],
l_out_features=d_args["filts"][2][-1])
self.bn_before_gru = nn.BatchNorm1d(
num_features=d_args["filts"][2][-1])
self.gru = nn.GRU(
input_size=d_args["filts"][2][-1],
hidden_size=d_args["gru_node"],
num_layers=d_args["nb_gru_layer"],
batch_first=True,
)
self.fc1_gru = nn.Linear(in_features=d_args["gru_node"],
out_features=d_args["nb_fc_node"])
self.fc2_gru = nn.Linear(
in_features=d_args["nb_fc_node"],
out_features=d_args["nb_classes"],
bias=True,
)
self.sig = nn.Sigmoid()
self.logsoftmax = nn.LogSoftmax(dim=1)
def forward(self, x, Freq_aug=None):
nb_samp = x.shape[0]
len_seq = x.shape[1]
x = x.view(nb_samp, 1, len_seq)
x = self.Sinc_conv(x)
x = F.max_pool1d(torch.abs(x), 3)
x = self.first_bn(x)
x = self.selu(x)
x0 = self.block0(x)
y0 = self.avgpool(x0).view(x0.size(0),
-1) # torch.Size([batch, filter])
y0 = self.fc_attention0(y0)
y0 = self.sig(y0).view(y0.size(0), y0.size(1),
-1) # torch.Size([batch, filter, 1])
x = x0 * y0 + y0 # (batch, filter, time) x (batch, filter, 1)
x1 = self.block1(x)
y1 = self.avgpool(x1).view(x1.size(0),
-1) # torch.Size([batch, filter])
y1 = self.fc_attention1(y1)
y1 = self.sig(y1).view(y1.size(0), y1.size(1),
-1) # torch.Size([batch, filter, 1])
x = x1 * y1 + y1 # (batch, filter, time) x (batch, filter, 1)
x2 = self.block2(x)
y2 = self.avgpool(x2).view(x2.size(0),
-1) # torch.Size([batch, filter])
y2 = self.fc_attention2(y2)
y2 = self.sig(y2).view(y2.size(0), y2.size(1),
-1) # torch.Size([batch, filter, 1])
x = x2 * y2 + y2 # (batch, filter, time) x (batch, filter, 1)
x3 = self.block3(x)
y3 = self.avgpool(x3).view(x3.size(0),
-1) # torch.Size([batch, filter])
y3 = self.fc_attention3(y3)
y3 = self.sig(y3).view(y3.size(0), y3.size(1),
-1) # torch.Size([batch, filter, 1])
x = x3 * y3 + y3 # (batch, filter, time) x (batch, filter, 1)
x4 = self.block4(x)
y4 = self.avgpool(x4).view(x4.size(0),
-1) # torch.Size([batch, filter])
y4 = self.fc_attention4(y4)
y4 = self.sig(y4).view(y4.size(0), y4.size(1),
-1) # torch.Size([batch, filter, 1])
x = x4 * y4 + y4 # (batch, filter, time) x (batch, filter, 1)
x5 = self.block5(x)
y5 = self.avgpool(x5).view(x5.size(0),
-1) # torch.Size([batch, filter])
y5 = self.fc_attention5(y5)
y5 = self.sig(y5).view(y5.size(0), y5.size(1),
-1) # torch.Size([batch, filter, 1])
x = x5 * y5 + y5 # (batch, filter, time) x (batch, filter, 1)
x = self.bn_before_gru(x)
x = self.selu(x)
x = x.permute(0, 2, 1) # (batch, filt, time) >> (batch, time, filt)
self.gru.flatten_parameters()
x, _ = self.gru(x)
x = x[:, -1, :]
last_hidden = self.fc1_gru(x)
x = self.fc2_gru(last_hidden)
output = self.logsoftmax(x)
return last_hidden, output
def _make_attention_fc(self, in_features, l_out_features):
l_fc = []
l_fc.append(
nn.Linear(in_features=in_features, out_features=l_out_features))
return nn.Sequential(*l_fc)
def _make_layer(self, nb_blocks, nb_filts, first=False):
layers = []
# def __init__(self, nb_filts, first = False):
for i in range(nb_blocks):
first = first if i == 0 else False
layers.append(Residual_block(nb_filts=nb_filts, first=first))
if i == 0:
nb_filts[0] = nb_filts[1]
return nn.Sequential(*layers)
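# ---------------------------------------------------------------------------
# Hypothetical smoke test for the RawNet2-style Model above. These d_args are
# illustrative guesses shaped like the ASVspoof RawNet2 baseline config, not
# the repo's actual values.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    d_args = {
        "first_conv": 1024,
        "in_channels": 1,
        "filts": [20, [20, 20], [20, 128]],
        "gru_node": 64,
        "nb_gru_layer": 1,
        "nb_fc_node": 64,
        "nb_classes": 2,
    }
    net = Model(d_args).eval()
    wav = torch.randn(2, 64600)  # (batch, samples) of raw waveform
    with torch.no_grad():
        hidden, log_probs = net(wav)
    print(hidden.shape, log_probs.shape)  # expected: (2, 64) and (2, 2)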
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
class GraphAttentionLayer(nn.Module):
def __init__(self, in_dim, out_dim, **kwargs):
super().__init__()
# attention map
self.att_proj = nn.Linear(in_dim, out_dim)
self.att_weight = self._init_new_params(out_dim, 1)
# project
self.proj_with_att = nn.Linear(in_dim, out_dim)
self.proj_without_att = nn.Linear(in_dim, out_dim)
# batch norm
self.bn = nn.BatchNorm1d(out_dim)
# dropout for inputs
self.input_drop = nn.Dropout(p=0.2)
# activate
self.act = nn.SELU(inplace=True)
def forward(self, x):
'''
x :(#bs, #node, #dim)
'''
# apply input dropout
x = self.input_drop(x)
# derive attention map
att_map = self._derive_att_map(x)
# projection
x = self._project(x, att_map)
# apply batch norm
x = self._apply_BN(x)
x = self.act(x)
return x
def _pairwise_mul_nodes(self, x):
'''
Calculates pairwise multiplication of nodes.
- for attention map
x :(#bs, #node, #dim)
out_shape :(#bs, #node, #node, #dim)
'''
nb_nodes = x.size(1)
x = x.unsqueeze(2).expand(-1, -1, nb_nodes, -1)
x_mirror = x.transpose(1, 2)
return x * x_mirror
def _derive_att_map(self, x):
'''
x :(#bs, #node, #dim)
out_shape :(#bs, #node, #node, 1)
'''
att_map = self._pairwise_mul_nodes(x)
# size: (#bs, #node, #node, #dim_out)
att_map = torch.tanh(self.att_proj(att_map))
# size: (#bs, #node, #node, 1)
att_map = torch.matmul(att_map, self.att_weight)
att_map = F.softmax(att_map, dim=-2)
return att_map
def _project(self, x, att_map):
x1 = self.proj_with_att(torch.matmul(att_map.squeeze(-1), x))
x2 = self.proj_without_att(x)
return x1 + x2
def _apply_BN(self, x):
org_size = x.size()
x = x.view(-1, org_size[-1])
x = self.bn(x)
x = x.view(org_size)
return x
def _init_new_params(self, *size):
out = nn.Parameter(torch.FloatTensor(*size))
nn.init.xavier_normal_(out)
return out
class GraphPool(nn.Module):
def __init__(self, k, in_dim, p):
super().__init__()
self.k = k
self.sigmoid = nn.Sigmoid()
self.proj = nn.Linear(in_dim, 1)
self.drop = nn.Dropout(p=p) if p > 0 else nn.Identity()
self.in_dim = in_dim
def forward(self, h):
Z = self.drop(h)
weights = self.proj(Z)
scores = self.sigmoid(weights)
new_h = self.top_k_graph(scores, h, self.k)
return new_h
def top_k_graph(self, scores, h, k):
"""
args
=====
scores: attention-based weights (#bs, #node, 1)
h: graph data (#bs, #node, #dim)
k: ratio of remaining nodes, (float)
returns
=====
h: graph pool applied data (#bs, #node', #dim)
"""
n_nodes = max(int(h.size(1) * k), 2)
n_feat = h.size(2)
_, idx = torch.topk(scores, n_nodes, dim=1)
idx = idx.expand(-1, -1, n_feat)
h = h * scores
h = torch.gather(h, 1, idx)
return h
class CONV(nn.Module):
@staticmethod
def to_mel(hz):
return 2595 * np.log10(1 + hz / 700)
@staticmethod
def to_hz(mel):
return 700 * (10**(mel / 2595) - 1)
def __init__(self,
out_channels,
kernel_size,
sample_rate=16000,
in_channels=1,
stride=1,
padding=0,
dilation=1,
bias=False,
groups=1,
mask=False):
super().__init__()
if in_channels != 1:
msg = "SincConv only support one input channel (here, in_channels = {%i})" % (
in_channels)
raise ValueError(msg)
self.out_channels = out_channels
self.kernel_size = kernel_size
self.sample_rate = sample_rate
# Forcing the filters to be odd (i.e., perfectly symmetric)
if kernel_size % 2 == 0:
self.kernel_size = self.kernel_size + 1
self.stride = stride
self.padding = padding
self.dilation = dilation
self.mask = mask
if bias:
raise ValueError('SincConv does not support bias.')
if groups > 1:
raise ValueError('SincConv does not support groups.')
NFFT = 512
f = int(self.sample_rate / 2) * np.linspace(0, 1, int(NFFT / 2) + 1)
fmel = self.to_mel(f)
fmelmax = np.max(fmel)
fmelmin = np.min(fmel)
filbandwidthsmel = np.linspace(fmelmin, fmelmax, self.out_channels + 1)
filbandwidthsf = self.to_hz(filbandwidthsmel)
self.mel = filbandwidthsf
self.hsupp = torch.arange(-(self.kernel_size - 1) / 2,
(self.kernel_size - 1) / 2 + 1)
self.band_pass = torch.zeros(self.out_channels, self.kernel_size)
for i in range(len(self.mel) - 1):
fmin = self.mel[i]
fmax = self.mel[i + 1]
hHigh = (2*fmax/self.sample_rate) * \
np.sinc(2*fmax*self.hsupp/self.sample_rate)
hLow = (2*fmin/self.sample_rate) * \
np.sinc(2*fmin*self.hsupp/self.sample_rate)
hideal = hHigh - hLow
self.band_pass[i, :] = Tensor(np.hamming(
self.kernel_size)) * Tensor(hideal)
def forward(self, x, mask=False):
band_pass_filter = self.band_pass.clone().to(x.device)
if mask:
A = np.random.uniform(0, 20)
A = int(A)
A0 = random.randint(0, band_pass_filter.shape[0] - A)
band_pass_filter[A0:A0 + A, :] = 0
else:
band_pass_filter = band_pass_filter
self.filters = (band_pass_filter).view(self.out_channels, 1,
self.kernel_size)
return F.conv1d(x,
self.filters,
stride=self.stride,
padding=self.padding,
dilation=self.dilation,
bias=None,
groups=1)
class Residual_block(nn.Module):
def __init__(self, nb_filts, first=False):
super().__init__()
self.first = first
if not self.first:
self.bn1 = nn.BatchNorm2d(num_features=nb_filts[0])
self.conv1 = nn.Conv2d(in_channels=nb_filts[0],
out_channels=nb_filts[1],
kernel_size=(2, 3),
padding=(1, 1),
stride=1)
self.selu = nn.SELU(inplace=True)
self.bn2 = nn.BatchNorm2d(num_features=nb_filts[1])
self.conv2 = nn.Conv2d(in_channels=nb_filts[1],
out_channels=nb_filts[1],
kernel_size=(2, 3),
padding=(0, 1),
stride=1)
if nb_filts[0] != nb_filts[1]:
self.downsample = True
self.conv_downsample = nn.Conv2d(in_channels=nb_filts[0],
out_channels=nb_filts[1],
padding=(0, 1),
kernel_size=(1, 3),
stride=1)
else:
self.downsample = False
self.mp = nn.MaxPool2d((1, 3)) # self.mp = nn.MaxPool2d((1,4))
def forward(self, x):
identity = x
if not self.first:
out = self.bn1(x)
out = self.selu(out)
else:
out = x
out = self.conv1(x)
# print('out',out.shape)
out = self.bn2(out)
out = self.selu(out)
# print('out',out.shape)
out = self.conv2(out)
#print('conv2 out',out.shape)
if self.downsample:
identity = self.conv_downsample(identity)
out += identity
out = self.mp(out)
return out
class Model(nn.Module):
def __init__(self, d_args):
super().__init__()
self.d_args = d_args
filts = d_args["filts"]
self.conv_time = CONV(out_channels=filts[0],
kernel_size=d_args["first_conv"],
in_channels=1)
self.first_bn = nn.BatchNorm2d(num_features=1)
self.selu = nn.SELU(inplace=True)
self.encoder_T = nn.Sequential(
nn.Sequential(Residual_block(nb_filts=filts[1], first=True)),
nn.Sequential(Residual_block(nb_filts=filts[2])),
nn.Sequential(Residual_block(nb_filts=filts[3])),
nn.Sequential(Residual_block(nb_filts=filts[4])),
nn.Sequential(Residual_block(nb_filts=filts[4])),
nn.Sequential(Residual_block(nb_filts=filts[4])))
self.encoder_S = nn.Sequential(
nn.Sequential(Residual_block(nb_filts=filts[1], first=True)),
nn.Sequential(Residual_block(nb_filts=filts[2])),
nn.Sequential(Residual_block(nb_filts=filts[3])),
nn.Sequential(Residual_block(nb_filts=filts[4])),
nn.Sequential(Residual_block(nb_filts=filts[4])),
nn.Sequential(Residual_block(nb_filts=filts[4])))
self.GAT_layer_T = GraphAttentionLayer(64, 32)
self.GAT_layer_S = GraphAttentionLayer(64, 32)
self.GAT_layer_ST = GraphAttentionLayer(32, 16)
self.pool_T = GraphPool(0.64, 32, 0.3)
self.pool_S = GraphPool(0.81, 32, 0.3)
self.pool_ST = GraphPool(0.64, 16, 0.3)
self.proj_T = nn.Linear(14, 12)
self.proj_S = nn.Linear(23, 12)
self.proj_ST = nn.Linear(16, 1)
self.out_layer = nn.Linear(7, 2)
def forward(self, x, Freq_aug=False):
nb_samp1 = x.shape[0]
len_seq = x.shape[1]
x = x.view(nb_samp1, 1, len_seq)
x = self.conv_time(x, mask=Freq_aug)
x = x.unsqueeze(dim=1)
x = F.max_pool2d(torch.abs(x), (3, 3))
x = self.first_bn(x)
x = self.selu(x)
e_T = self.encoder_T(x) # (#bs, #filt, #spec, #seq)
e_T, _ = torch.max(torch.abs(e_T), dim=3) # max along time
gat_T = self.GAT_layer_T(e_T.transpose(1, 2))
pool_T = self.pool_T(gat_T) # (#bs, #node, #dim)
out_T = self.proj_T(pool_T.transpose(1, 2))
e_S = self.encoder_S(x)
e_S, _ = torch.max(torch.abs(e_S), dim=2) # max along freq
gat_S = self.GAT_layer_S(e_S.transpose(1, 2))
pool_S = self.pool_S(gat_S)
out_S = self.proj_S(pool_S.transpose(1, 2))
gat_ST = torch.mul(out_T, out_S)
gat_ST = self.GAT_layer_ST(gat_ST.transpose(1, 2))
pool_ST = self.pool_ST(gat_ST)
proj_ST = self.proj_ST(pool_ST).flatten(1)
output = self.out_layer(proj_ST)
return proj_ST, output
torch>=1.6.0
torchcontrib
numpy
soundfile
# Assumed two-step BladeDISC workflow: the first run traces and exports the
# optimized module (aasist_dcu_opt.pt); the second reloads and benchmarks it.
source /root/env_disc.sh
# Ops permitted to lower to MHLO (';'-separated list).
export TORCH_MHLO_OP_WHITE_LIST="aten::max;aten::batch_norm;aten::abs;aten::selu;prim::NumToTensor;aten::zeros_like;aten::size;aten::narrow;aten::cat;aten::selu_"
python3 main.py --eval --config ./config/AASIST-L.conf
python3 main_opt.py --eval --config ./config/AASIST-L.conf
"""
Utility functions
"""
import os
import random
import sys
import numpy as np
import torch
def str_to_bool(val):
"""Convert a string representation of truth to True or False.
Adapted from the Python implementation distutils.util.strtobool.
True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values
are 'n', 'no', 'f', 'false', 'off', and '0'. Raises ValueError if
'val' is anything else.
>>> str_to_bool('YES')
True
>>> str_to_bool('FALSE')
False
"""
val = val.lower()
if val in ('y', 'yes', 't', 'true', 'on', '1'):
return True
if val in ('n', 'no', 'f', 'false', 'off', '0'):
return False
raise ValueError('invalid truth value {}'.format(val))
def cosine_annealing(step, total_steps, lr_max, lr_min):
"""Cosine Annealing for learning rate decay scheduler"""
return lr_min + (lr_max -
lr_min) * 0.5 * (1 + np.cos(step / total_steps * np.pi))
def keras_decay(step, decay=0.0001):
"""Learning rate decay in Keras-style"""
return 1. / (1. + decay * step)
class SGDRScheduler(torch.optim.lr_scheduler._LRScheduler):
"""SGD with restarts scheduler"""
def __init__(self, optimizer, T0, T_mul, eta_min, last_epoch=-1):
self.Ti = T0
self.T_mul = T_mul
self.eta_min = eta_min
self.last_restart = 0
super().__init__(optimizer, last_epoch)
def get_lr(self):
T_cur = self.last_epoch - self.last_restart
if T_cur >= self.Ti:
self.last_restart = self.last_epoch
self.Ti = self.Ti * self.T_mul
T_cur = 0
return [
self.eta_min + (base_lr - self.eta_min) *
(1 + np.cos(np.pi * T_cur / self.Ti)) / 2
for base_lr in self.base_lrs
]
def _get_optimizer(model_parameters, optim_config):
"""Defines optimizer according to the given config"""
optimizer_name = optim_config['optimizer']
if optimizer_name == 'sgd':
optimizer = torch.optim.SGD(model_parameters,
lr=optim_config['base_lr'],
momentum=optim_config['momentum'],
weight_decay=optim_config['weight_decay'],
nesterov=optim_config['nesterov'])
elif optimizer_name == 'adam':
optimizer = torch.optim.Adam(model_parameters,
lr=optim_config['base_lr'],
betas=optim_config['betas'],
weight_decay=optim_config['weight_decay'],
amsgrad=str_to_bool(
optim_config['amsgrad']))
else:
print('Un-known optimizer', optimizer_name)
sys.exit()
return optimizer
def _get_scheduler(optimizer, optim_config):
"""
Defines learning rate scheduler according to the given config
"""
if optim_config['scheduler'] == 'multistep':
scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer,
milestones=optim_config['milestones'],
gamma=optim_config['lr_decay'])
elif optim_config['scheduler'] == 'sgdr':
scheduler = SGDRScheduler(optimizer, optim_config['T0'],
optim_config['Tmult'],
optim_config['lr_min'])
elif optim_config['scheduler'] == 'cosine':
total_steps = optim_config['epochs'] * \
optim_config['steps_per_epoch']
scheduler = torch.optim.lr_scheduler.LambdaLR(
optimizer,
lr_lambda=lambda step: cosine_annealing(
step,
total_steps,
1, # since lr_lambda computes multiplicative factor
optim_config['lr_min'] / optim_config['base_lr']))
elif optim_config['scheduler'] == 'keras_decay':
scheduler = torch.optim.lr_scheduler.LambdaLR(
optimizer, lr_lambda=lambda step: keras_decay(step))
else:
scheduler = None
return scheduler
def create_optimizer(model_parameters, optim_config):
"""Defines an optimizer and a scheduler"""
optimizer = _get_optimizer(model_parameters, optim_config)
scheduler = _get_scheduler(optimizer, optim_config)
return optimizer, scheduler
def seed_worker(worker_id):
"""
Generate a per-worker seed for torch.utils.data.DataLoader workers
"""
worker_seed = torch.initial_seed() % 2**32
np.random.seed(worker_seed)
random.seed(worker_seed)
def set_seed(seed, config = None):
"""
Set the initial seed for reproducibility
"""
if config is None:
raise ValueError("config should not be None")
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = str_to_bool(config["cudnn_deterministic_toggle"])
torch.backends.cudnn.benchmark = str_to_bool(config["cudnn_benchmark_toggle"])
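# ---------------------------------------------------------------------------
# Minimal usage sketch for create_optimizer (hypothetical config values; the
# real ones come from the experiment's .conf file, with "epochs" and
# "steps_per_epoch" filled in by main.py).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch.nn as nn
    net = nn.Linear(4, 2)
    optim_config = {
        "optimizer": "adam",
        "base_lr": 1e-4,
        "betas": [0.9, 0.999],
        "weight_decay": 1e-4,
        "amsgrad": "False",
        "scheduler": "cosine",
        "epochs": 2,
        "steps_per_epoch": 10,
        "lr_min": 5e-6,
    }
    optimizer, scheduler = create_optimizer(net.parameters(), optim_config)
    for _ in range(optim_config["epochs"] * optim_config["steps_per_epoch"]):
        optimizer.step()  # one dummy step per batch
        scheduler.step()
    print("final lr:", scheduler.get_last_lr())  # decays toward lr_min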