"examples/vscode:/vscode.git/clone" did not exist on "67bef2027cc461af5bbe73b3c0f35bb1350f5aa8"
Unverified Commit ea06688e authored by maqy's avatar maqy Committed by GitHub
Browse files

[Model] add model example EvolveGCN. (#3190)



* add evolveGCN example

* small fix

* fix defect

* fix defect
Co-authored-by: default avatarQuan (Andy) Gan <coin2028@hotmail.com>
parent a536772e
# Implementing EvolveGCN with DGL
Paper: [EvolveGCN](https://arxiv.org/abs/1902.10191)
Official code: [IBM/EvolveGCN](https://github.com/IBM/EvolveGCN)
Another implementation: [pyG_temporal](https://github.com/benedekrozemberczki/pytorch_geometric_temporal/blob/master/torch_geometric_temporal/nn/recurrent/evolvegcno.py)
## Dependencies
* dgl
* pandas
* numpy
## Run
* download the Elliptic dataset from [kaggle](https://kaggle.com/ellipticco/elliptic-data-set)
* unzip the dataset into a raw directory, such as /home/Elliptic/elliptic_bitcoin_dataset/
* make a new directory to store the processed data, such as /home/Elliptic/processed/
* run train.py:
```bash
python train.py --raw-dir /home/Elliptic/elliptic_bitcoin_dataset/ --processed-dir /home/Elliptic/processed/
```
## Result
EvolveGCN-O can match the results of Fig. 3 and Fig. 4 in the paper.
(You may need to run it several times and average the results.)
## Attention
* Currently only the Elliptic dataset is supported.
* EvolveGCN-H is not stable on the Elliptic dataset; the official code behaves the same way.
Results from the official code with EvolveGCN-H (see the run sketch after this list):
1. With the seed set to 1234, the final result is:
> TEST epoch 189: TEST measures for class 1 - precision 0.3875 - recall 0.5714 - f1 0.4618
2. Without setting a seed manually, the same code run three times gives:
> TEST epoch 168: TEST measures for class 1 - precision 0.3189 - recall 0.0680 - f1 0.1121
>
> TEST epoch 270: TEST measures for class 1 - precision 0.3517 - recall 0.3018 - f1 0.3249
>
> TEST epoch 455: TEST measures for class 1 - precision 0.2271 - recall 0.2995 - f1 0.2583
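If you still want to try EvolveGCN-H, the argparse help in train.py suggests adjusting the loss class weights; a run might look like this (the paths are placeholders):
```bash
python train.py --model EvolveGCN-H --loss-class-weight 0.25,0.75 \
    --raw-dir /home/Elliptic/elliptic_bitcoin_dataset/ --processed-dir /home/Elliptic/processed/
```

# dataset.py (imported by train.py via `from dataset import EllipticDataset`)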
import os
import pandas
import numpy
import torch
import dgl
def process_raw_data(raw_dir, processed_dir):
r"""
Description
-----------
    Preprocess the Elliptic dataset following the official EvolveGCN instructions:
    github.com/IBM/EvolveGCN/blob/master/elliptic_construction.md
    The main purpose is to convert the original ids to contiguous ids starting at 0.
"""
oid_nid_path = os.path.join(processed_dir, 'oid_nid.npy')
id_label_path = os.path.join(processed_dir, 'id_label.npy')
id_time_features_path = os.path.join(processed_dir, 'id_time_features.npy')
src_dst_time_path = os.path.join(processed_dir, 'src_dst_time.npy')
if os.path.exists(oid_nid_path) and os.path.exists(id_label_path) and \
os.path.exists(id_time_features_path) and os.path.exists(src_dst_time_path):
print("The preprocessed data already exists, skip the preprocess stage!")
return
print("starting process raw data in {}".format(raw_dir))
id_label = pandas.read_csv(os.path.join(raw_dir, 'elliptic_txs_classes.csv'))
src_dst = pandas.read_csv(os.path.join(raw_dir, 'elliptic_txs_edgelist.csv'))
    # elliptic_txs_features.csv has no header; its rows are in the same order as elliptic_txs_classes.csv
id_time_features = pandas.read_csv(os.path.join(raw_dir, 'elliptic_txs_features.csv'), header=None)
# get oldId_newId
oid_nid = id_label.loc[:, ['txId']]
oid_nid = oid_nid.rename(columns={'txId': 'originalId'})
oid_nid.insert(1, 'newId', range(len(oid_nid)))
    # map classes unknown,1,2 to -1,1,0 and construct id_label. Class 1 means illicit.
id_label = pandas.concat(
[oid_nid['newId'], id_label['class'].map({'unknown': -1.0, '1': 1.0, '2': 0.0})], axis=1)
    # replace originalId with newId.
    # Attention: the timestamps in the features start at 1.
id_time_features[0] = oid_nid['newId']
# construct originalId2newId dict
oid_nid_dict = oid_nid.set_index(['originalId'])['newId'].to_dict()
# construct newId2timestamp dict
nid_time_dict = id_time_features.set_index([0])[1].to_dict()
# Map id in edgelist to newId, and add a timestamp to each edge.
    # Attention: per the EvolveGCN official instructions, the edgelist timestamps start at 0, rather than 1.
    # See: github.com/IBM/EvolveGCN/blob/master/elliptic_construction.md
    # Here we do not follow the official instructions, so the edgelist timestamps also start at 1.
    # In this EvolveGCN example, the edge timestamps are not used.
    #
    # Note: in the dataset, the src and dst nodes of an edge share the same timestamp, so it is easy to set the edge's timestamp.
new_src = src_dst['txId1'].map(oid_nid_dict).rename('newSrc')
new_dst = src_dst['txId2'].map(oid_nid_dict).rename('newDst')
edge_time = new_src.map(nid_time_dict).rename('timestamp')
src_dst_time = pandas.concat([new_src, new_dst, edge_time], axis=1)
    # Save oid_nid, id_label, id_time_features and src_dst_time to disk, converting them to numpy arrays first.
    # oid_nid: int. id_label: int. id_time_features: float. src_dst_time: int.
oid_nid = oid_nid.to_numpy(dtype=int)
id_label = id_label.to_numpy(dtype=int)
id_time_features = id_time_features.to_numpy(dtype=float)
src_dst_time = src_dst_time.to_numpy(dtype=int)
numpy.save(oid_nid_path, oid_nid)
numpy.save(id_label_path, id_label)
numpy.save(id_time_features_path, id_time_features)
numpy.save(src_dst_time_path, src_dst_time)
print("Process Elliptic raw data done, data has saved into {}".format(processed_dir))
class EllipticDataset:
def __init__(self, raw_dir, processed_dir, self_loop=True, reverse_edge=True):
self.raw_dir = raw_dir
        self.processed_dir = processed_dir
self.self_loop = self_loop
self.reverse_edge = reverse_edge
def process(self):
        process_raw_data(self.raw_dir, self.processed_dir)
        id_time_features = torch.Tensor(numpy.load(os.path.join(self.processed_dir, 'id_time_features.npy')))
        id_label = torch.IntTensor(numpy.load(os.path.join(self.processed_dir, 'id_label.npy')))
        src_dst_time = torch.IntTensor(numpy.load(os.path.join(self.processed_dir, 'src_dst_time.npy')))
src = src_dst_time[:, 0]
dst = src_dst_time[:, 1]
        # id_label[:, 0] holds the node ids; it is used here to add self-loops
if self.self_loop:
if self.reverse_edge:
g = dgl.graph(data=(torch.cat((src, dst, id_label[:, 0])), torch.cat((dst, src, id_label[:, 0]))),
num_nodes=id_label.shape[0])
g.edata['timestamp'] = torch.cat((src_dst_time[:, 2], src_dst_time[:, 2], id_time_features[:, 1].int()))
else:
g = dgl.graph(data=(torch.cat((src, id_label[:, 0])), torch.cat((dst, id_label[:, 0]))),
num_nodes=id_label.shape[0])
g.edata['timestamp'] = torch.cat((src_dst_time[:, 2], id_time_features[:, 1].int()))
else:
if self.reverse_edge:
g = dgl.graph(data=(torch.cat((src, dst)), torch.cat((dst, src))),
num_nodes=id_label.shape[0])
g.edata['timestamp'] = torch.cat((src_dst_time[:, 2], src_dst_time[:, 2]))
else:
g = dgl.graph(data=(src, dst),
num_nodes=id_label.shape[0])
g.edata['timestamp'] = src_dst_time[:, 2]
time_features = id_time_features[:, 1:]
label = id_label[:, 1]
g.ndata['label'] = label
g.ndata['feat'] = time_features
        # node masks used to construct time-based subgraphs
node_mask_by_time = []
start_time = int(torch.min(id_time_features[:, 1]))
end_time = int(torch.max(id_time_features[:, 1]))
for i in range(start_time, end_time + 1):
node_mask = id_time_features[:, 1] == i
node_mask_by_time.append(node_mask)
return g, node_mask_by_time
@property
def num_classes(self):
r"""Number of classes for each node."""
return 2
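A minimal usage sketch for `EllipticDataset` (the paths are placeholders; this mirrors how train.py below consumes it):
```python
import dgl

dataset = EllipticDataset(raw_dir='/home/Elliptic/elliptic_bitcoin_dataset/',
                          processed_dir='/home/Elliptic/processed/',
                          self_loop=True, reverse_edge=True)
g, node_mask_by_time = dataset.process()
# one boolean mask per time step; slice the full graph into per-timestamp subgraphs
subgraphs = [dgl.node_subgraph(g, mask) for mask in node_mask_by_time]
```

# model.py (imported by train.py via `from model import EvolveGCNO, EvolveGCNH`)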
import torch
import torch.nn as nn
from torch.nn import init
from dgl.nn.pytorch import GraphConv
from torch.nn.parameter import Parameter
class MatGRUCell(torch.nn.Module):
"""
    GRU cell whose hidden state is a matrix (the evolving GCN weight), similar to the official code.
Please refer to section 3.4 of the paper for the formula.
"""
def __init__(self, in_feats, out_feats):
super().__init__()
self.update = MatGRUGate(in_feats,
out_feats,
torch.nn.Sigmoid())
self.reset = MatGRUGate(in_feats,
out_feats,
torch.nn.Sigmoid())
self.htilda = MatGRUGate(in_feats,
out_feats,
torch.nn.Tanh())
def forward(self, prev_Q, z_topk=None):
if z_topk is None:
z_topk = prev_Q
update = self.update(z_topk, prev_Q)
reset = self.reset(z_topk, prev_Q)
h_cap = reset * prev_Q
h_cap = self.htilda(z_topk, h_cap)
new_Q = (1 - update) * prev_Q + update * h_cap
return new_Q
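# Shape sketch (illustrative, not part of the original example): MatGRUCell evolves
# a GCN weight matrix instead of a hidden state vector, e.g.
#
#     cell = MatGRUCell(in_feats=166, out_feats=76)
#     prev_W = torch.randn(166, 76)   # GCN weight at time t-1
#     new_W = cell(prev_W)            # EvolveGCN-O style: the input defaults to prev_W itself
#     assert new_W.shape == (166, 76)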
class MatGRUGate(torch.nn.Module):
"""
    GRU gate operating on matrices, similar to the official code.
Please refer to section 3.4 of the paper for the formula.
"""
def __init__(self, rows, cols, activation):
super().__init__()
self.activation = activation
self.W = Parameter(torch.Tensor(rows, rows))
self.U = Parameter(torch.Tensor(rows, rows))
self.bias = Parameter(torch.Tensor(rows, cols))
self.reset_parameters()
def reset_parameters(self):
init.xavier_uniform_(self.W)
init.xavier_uniform_(self.U)
init.zeros_(self.bias)
def forward(self, x, hidden):
out = self.activation(self.W.matmul(x) + \
self.U.matmul(hidden) + \
self.bias)
return out
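# Note: unlike a standard GRU gate, W and U here left-multiply matrix-valued inputs,
# i.e. out = activation(W @ x + U @ hidden + bias), with x and hidden of shape (rows, cols).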
class TopK(torch.nn.Module):
"""
    Similar to the official `egcn_h.py`. We only consider the nodes in a timestamp-based subgraph,
    so `k` must not exceed the minimum number of nodes across all subgraphs.
Please refer to section 3.4 of the paper for the formula.
"""
def __init__(self, feats, k):
super().__init__()
self.scorer = Parameter(torch.Tensor(feats, 1))
self.reset_parameters()
self.k = k
def reset_parameters(self):
init.xavier_uniform_(self.scorer)
def forward(self, node_embs):
scores = node_embs.matmul(self.scorer) / self.scorer.norm().clamp(min=1e-6)
vals, topk_indices = scores.view(-1).topk(self.k)
out = node_embs[topk_indices] * torch.tanh(scores[topk_indices].view(-1, 1))
# we need to transpose the output
return out.t()
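# Shape sketch (illustrative): TopK pools per-node embeddings into a (feats, k) matrix
# that MatGRUCell can consume as the GRU "input", e.g.
#
#     pool = TopK(feats=166, k=76)
#     node_embs = torch.randn(3000, 166)  # nodes of one timestamp's subgraph
#     x_tilde = pool(node_embs)
#     assert x_tilde.shape == (166, 76)   # matches the shape of the evolving GCN weight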
class EvolveGCNH(nn.Module):
def __init__(self, in_feats=166, n_hidden=76, num_layers=2, n_classes=2, classifier_hidden=510):
# default parameters follow the official config
super(EvolveGCNH, self).__init__()
self.num_layers = num_layers
self.pooling_layers = nn.ModuleList()
self.recurrent_layers = nn.ModuleList()
self.gnn_convs = nn.ModuleList()
self.gcn_weights_list = nn.ParameterList()
self.pooling_layers.append(TopK(in_feats, n_hidden))
# similar to EvolveGCNO
self.recurrent_layers.append(MatGRUCell(in_feats=in_feats, out_feats=n_hidden))
self.gcn_weights_list.append(Parameter(torch.Tensor(in_feats, n_hidden)))
self.gnn_convs.append(
GraphConv(in_feats=in_feats, out_feats=n_hidden, bias=False, activation=nn.RReLU(), weight=False))
for _ in range(num_layers - 1):
self.pooling_layers.append(TopK(n_hidden, n_hidden))
self.recurrent_layers.append(MatGRUCell(in_feats=n_hidden, out_feats=n_hidden))
self.gcn_weights_list.append(Parameter(torch.Tensor(n_hidden, n_hidden)))
self.gnn_convs.append(
GraphConv(in_feats=n_hidden, out_feats=n_hidden, bias=False, activation=nn.RReLU(), weight=False))
self.mlp = nn.Sequential(nn.Linear(n_hidden, classifier_hidden),
nn.ReLU(),
nn.Linear(classifier_hidden, n_classes))
self.reset_parameters()
def reset_parameters(self):
for gcn_weight in self.gcn_weights_list:
init.xavier_uniform_(gcn_weight)
def forward(self, g_list):
feature_list = []
for g in g_list:
feature_list.append(g.ndata['feat'])
for i in range(self.num_layers):
W = self.gcn_weights_list[i]
for j, g in enumerate(g_list):
X_tilde = self.pooling_layers[i](feature_list[j])
W = self.recurrent_layers[i](W, X_tilde)
feature_list[j] = self.gnn_convs[i](g, feature_list[j], weight=W)
return self.mlp(feature_list[-1])
class EvolveGCNO(nn.Module):
def __init__(self, in_feats=166, n_hidden=256, num_layers=2, n_classes=2, classifier_hidden=307):
# default parameters follow the official config
super(EvolveGCNO, self).__init__()
self.num_layers = num_layers
self.recurrent_layers = nn.ModuleList()
self.gnn_convs = nn.ModuleList()
self.gcn_weights_list = nn.ParameterList()
        # In the paper, EvolveGCN-O uses an LSTM as the RNN layer, but the official code
        # uses a GRU. Here we follow the official code.
        # See: https://github.com/IBM/EvolveGCN/blob/90869062bbc98d56935e3d92e1d9b1b4c25be593/egcn_o.py#L53
        # PS: I tried using torch.nn.LSTM directly,
        # like [pyg_temporal](github.com/benedekrozemberczki/pytorch_geometric_temporal/blob/master/torch_geometric_temporal/nn/recurrent/evolvegcno.py),
        # but the performance was worse than with torch.nn.GRU.
        # PPS: I think torch.nn.GRU cannot match the manually implemented GRU cell in the official code,
        # so we follow the official code here.
self.recurrent_layers.append(MatGRUCell(in_feats=in_feats, out_feats=n_hidden))
self.gcn_weights_list.append(Parameter(torch.Tensor(in_feats, n_hidden)))
self.gnn_convs.append(
GraphConv(in_feats=in_feats, out_feats=n_hidden, bias=False, activation=nn.RReLU(), weight=False))
for _ in range(num_layers - 1):
self.recurrent_layers.append(MatGRUCell(in_feats=n_hidden, out_feats=n_hidden))
self.gcn_weights_list.append(Parameter(torch.Tensor(n_hidden, n_hidden)))
self.gnn_convs.append(
GraphConv(in_feats=n_hidden, out_feats=n_hidden, bias=False, activation=nn.RReLU(), weight=False))
self.mlp = nn.Sequential(nn.Linear(n_hidden, classifier_hidden),
nn.ReLU(),
nn.Linear(classifier_hidden, n_classes))
self.reset_parameters()
def reset_parameters(self):
for gcn_weight in self.gcn_weights_list:
init.xavier_uniform_(gcn_weight)
def forward(self, g_list):
feature_list = []
for g in g_list:
feature_list.append(g.ndata['feat'])
for i in range(self.num_layers):
W = self.gcn_weights_list[i]
for j, g in enumerate(g_list):
                # Attention: I tried using the code below to set gcn.weight (similar to pyG_temporal),
                # but it does not work: the gradient function is lost in this situation.
                # More discussion here: https://github.com/benedekrozemberczki/pytorch_geometric_temporal/issues/80
                # ====================================================
                # W = self.gnn_convs[i].weight[None, :, :]
                # W, _ = self.recurrent_layers[i](W)
                # self.gnn_convs[i].weight = nn.Parameter(W.squeeze())
                # ====================================================
                # If the following line is removed, the model degenerates to a plain GCN.
W = self.recurrent_layers[i](W)
feature_list[j] = self.gnn_convs[i](g, feature_list[j], weight=W)
return self.mlp(feature_list[-1])
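A minimal forward-pass sketch under assumed inputs (random graphs stand in for the Elliptic subgraphs; self-loops are added so `GraphConv` sees no zero-in-degree nodes):
```python
import dgl
import torch

model = EvolveGCNO(in_feats=166)
g_list = []
for _ in range(6):  # a window of n_hist_steps=5 historical steps plus the target step
    g = dgl.add_self_loop(dgl.rand_graph(100, 400))
    g.ndata['feat'] = torch.randn(100, 166)
    g_list.append(g)
logits = model(g_list)  # logits for the nodes of the last graph in the window
assert logits.shape == (100, 2)
```

# train.py (the entry point; run it as shown in the README above)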
import argparse
import time
import dgl
import torch
import torch.nn.functional as F
from dataset import EllipticDataset
from model import EvolveGCNO, EvolveGCNH
from utils import Measure
def train(args, device):
elliptic_dataset = EllipticDataset(raw_dir=args.raw_dir,
processed_dir=args.processed_dir,
self_loop=True,
reverse_edge=True)
g, node_mask_by_time = elliptic_dataset.process()
num_classes = elliptic_dataset.num_classes
cached_subgraph = []
cached_labeled_node_mask = []
for i in range(len(node_mask_by_time)):
        # self-loop edges were added when constructing the full graph, not here
node_subgraph = dgl.node_subgraph(graph=g, nodes=node_mask_by_time[i])
cached_subgraph.append(node_subgraph.to(device))
valid_node_mask = node_subgraph.ndata['label'] >= 0
cached_labeled_node_mask.append(valid_node_mask)
if args.model == 'EvolveGCN-O':
model = EvolveGCNO(in_feats=int(g.ndata['feat'].shape[1]),
n_hidden=args.n_hidden,
num_layers=args.n_layers)
elif args.model == 'EvolveGCN-H':
model = EvolveGCNH(in_feats=int(g.ndata['feat'].shape[1]),
num_layers=args.n_layers)
else:
        raise NotImplementedError('Unsupported model {}'.format(args.model))
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    # train/valid/test split (time steps 0-30, 31-35, 36-48) follows the paper
train_max_index = 30
valid_max_index = 35
test_max_index = 48
time_window_size = args.n_hist_steps
loss_class_weight = [float(w) for w in args.loss_class_weight.split(',')]
loss_class_weight = torch.Tensor(loss_class_weight).to(device)
train_measure = Measure(num_classes=num_classes, target_class=args.eval_class_id)
valid_measure = Measure(num_classes=num_classes, target_class=args.eval_class_id)
test_measure = Measure(num_classes=num_classes, target_class=args.eval_class_id)
test_res_f1 = 0
for epoch in range(args.num_epochs):
model.train()
for i in range(time_window_size, train_max_index + 1):
g_list = cached_subgraph[i - time_window_size:i + 1]
predictions = model(g_list)
            # keep only the predictions that have labels
predictions = predictions[cached_labeled_node_mask[i]]
labels = cached_subgraph[i].ndata['label'][cached_labeled_node_mask[i]].long()
loss = F.cross_entropy(predictions, labels, weight=loss_class_weight)
optimizer.zero_grad()
loss.backward()
optimizer.step()
train_measure.append_measures(predictions, labels)
        # compute the measures for this training epoch.
cl_precision, cl_recall, cl_f1 = train_measure.get_total_measure()
train_measure.update_best_f1(cl_f1, epoch)
# reset measures for next epoch
train_measure.reset_info()
print("Train Epoch {} | class {} | precision:{:.4f} | recall: {:.4f} | f1: {:.4f}"
.format(epoch, args.eval_class_id, cl_precision, cl_recall, cl_f1))
# eval
model.eval()
for i in range(train_max_index + 1, valid_max_index + 1):
g_list = cached_subgraph[i - time_window_size:i + 1]
predictions = model(g_list)
            # keep only the node predictions that have labels
predictions = predictions[cached_labeled_node_mask[i]]
labels = cached_subgraph[i].ndata['label'][cached_labeled_node_mask[i]].long()
valid_measure.append_measures(predictions, labels)
        # compute the measures for this evaluation epoch.
cl_precision, cl_recall, cl_f1 = valid_measure.get_total_measure()
valid_measure.update_best_f1(cl_f1, epoch)
# reset measures for next epoch
valid_measure.reset_info()
print("Eval Epoch {} | class {} | precision:{:.4f} | recall: {:.4f} | f1: {:.4f}"
.format(epoch, args.eval_class_id, cl_precision, cl_recall, cl_f1))
# early stop
if epoch - valid_measure.target_best_f1_epoch >= args.patience:
print("Best eval Epoch {}, Cur Epoch {}".format(valid_measure.target_best_f1_epoch, epoch))
break
        # if the current validation f1 is the best so far, run the test
if epoch == valid_measure.target_best_f1_epoch:
print("###################Epoch {} Test###################".format(epoch))
for i in range(valid_max_index + 1, test_max_index + 1):
g_list = cached_subgraph[i - time_window_size:i + 1]
predictions = model(g_list)
                # keep only the predictions that have labels
predictions = predictions[cached_labeled_node_mask[i]]
labels = cached_subgraph[i].ndata['label'][cached_labeled_node_mask[i]].long()
test_measure.append_measures(predictions, labels)
            # we compute per-subgraph measures at test time to match Fig. 4 in the EvolveGCN paper.
cl_precisions, cl_recalls, cl_f1s = test_measure.get_each_timestamp_measure()
for index, (sub_p, sub_r, sub_f1) in enumerate(zip(cl_precisions, cl_recalls, cl_f1s)):
print(" Test | Time {} | precision:{:.4f} | recall: {:.4f} | f1: {:.4f}"
.format(valid_max_index + index + 2, sub_p, sub_r, sub_f1))
            # compute the overall measures for this test pass.
cl_precision, cl_recall, cl_f1 = test_measure.get_total_measure()
test_measure.update_best_f1(cl_f1, epoch)
# reset measures for next test
test_measure.reset_info()
test_res_f1 = cl_f1
print(" Test | Epoch {} | class {} | precision:{:.4f} | recall: {:.4f} | f1: {:.4f}"
.format(epoch, args.eval_class_id, cl_precision, cl_recall, cl_f1))
print("Best test f1 is {}, in Epoch {}"
.format(test_measure.target_best_f1, test_measure.target_best_f1_epoch))
if test_measure.target_best_f1_epoch != valid_measure.target_best_f1_epoch:
print("The Epoch get best Valid measure not get the best Test measure, "
"please checkout the test result in Epoch {}, which f1 is {}"
.format(valid_measure.target_best_f1_epoch, test_res_f1))
if __name__ == "__main__":
argparser = argparse.ArgumentParser("EvolveGCN")
    argparser.add_argument('--model', type=str, default='EvolveGCN-O',
                           help="Choose EvolveGCN-O or EvolveGCN-H; "
                                "note that EvolveGCN-H does not perform well on the Elliptic dataset.")
argparser.add_argument('--raw-dir', type=str,
default='/home/Elliptic/elliptic_bitcoin_dataset/',
help="Dir after unzip downloaded dataset, which contains 3 csv files.")
argparser.add_argument('--processed-dir', type=str,
default='/home/Elliptic/processed/',
help="Dir to store processed raw data.")
argparser.add_argument('--gpu', type=int, default=0,
help="GPU device ID. Use -1 for CPU training.")
argparser.add_argument('--num-epochs', type=int, default=1000)
argparser.add_argument('--n-hidden', type=int, default=256)
argparser.add_argument('--n-layers', type=int, default=2)
    argparser.add_argument('--n-hist-steps', type=int, default=5,
                           help="If set to 5, the first batch uses the historical data "
                                "of time steps 0-4 to predict time step 5.")
argparser.add_argument('--lr', type=float, default=0.001)
    argparser.add_argument('--loss-class-weight', type=str, default='0.35,0.65',
                           help="Weights for the loss function. Following the official code, "
                                "change this to 0.25,0.75 when using EvolveGCN-H.")
    argparser.add_argument('--eval-class-id', type=int, default=1,
                           help="Class to evaluate. On Elliptic, class 1 (illicit) is the main interest.")
argparser.add_argument('--patience', type=int, default=100,
help="Patience for early stopping.")
args = argparser.parse_args()
if args.gpu >= 0:
device = torch.device('cuda:%d' % args.gpu)
else:
device = torch.device('cpu')
start_time = time.perf_counter()
train(args, device)
print("train time is: {}".format(time.perf_counter() - start_time))
def calculate_measure(tp, fn, fp):
    # avoid nan when there are no true positives
if tp == 0:
return 0, 0, 0
p = tp * 1.0 / (tp + fp)
r = tp * 1.0 / (tp + fn)
if (p + r) > 0:
f1 = 2.0 * (p * r) / (p + r)
else:
f1 = 0
return p, r, f1
class Measure(object):
def __init__(self, num_classes, target_class):
"""
Args:
num_classes: number of classes.
target_class: target class we focus on, used to print info and do early stopping.
"""
self.num_classes = num_classes
self.target_class = target_class
self.true_positives = {}
self.false_positives = {}
self.false_negatives = {}
self.target_best_f1 = 0.0
self.target_best_f1_epoch = 0
self.reset_info()
def reset_info(self):
"""
        Reset the accumulated statistics; called after each epoch.
"""
self.true_positives = {cur_class: [] for cur_class in range(self.num_classes)}
self.false_positives = {cur_class: [] for cur_class in range(self.num_classes)}
self.false_negatives = {cur_class: [] for cur_class in range(self.num_classes)}
def append_measures(self, predictions, labels):
predicted_classes = predictions.argmax(dim=1)
for cl in range(self.num_classes):
cl_indices = (labels == cl)
pos = (predicted_classes == cl)
hits = (predicted_classes[cl_indices] == labels[cl_indices])
tp = hits.sum()
fn = hits.size(0) - tp
fp = pos.sum() - tp
self.true_positives[cl].append(tp.cpu())
self.false_negatives[cl].append(fn.cpu())
self.false_positives[cl].append(fp.cpu())
def get_each_timestamp_measure(self):
precisions = []
recalls = []
f1s = []
for i in range(len(self.true_positives[self.target_class])):
tp = self.true_positives[self.target_class][i]
fn = self.false_negatives[self.target_class][i]
fp = self.false_positives[self.target_class][i]
p, r, f1 = calculate_measure(tp, fn, fp)
precisions.append(p)
recalls.append(r)
f1s.append(f1)
return precisions, recalls, f1s
def get_total_measure(self):
tp = sum(self.true_positives[self.target_class])
fn = sum(self.false_negatives[self.target_class])
fp = sum(self.false_positives[self.target_class])
p, r, f1 = calculate_measure(tp, fn, fp)
return p, r, f1
def update_best_f1(self, cur_f1, cur_epoch):
if cur_f1 > self.target_best_f1:
self.target_best_f1 = cur_f1
self.target_best_f1_epoch = cur_epoch
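A small self-contained sanity check for `Measure` (illustrative numbers only, not results from the paper):
```python
import torch

m = Measure(num_classes=2, target_class=1)
logits = torch.tensor([[0.2, 0.8], [0.9, 0.1], [0.3, 0.7]])  # predicted classes: 1, 0, 1
labels = torch.tensor([1, 1, 0])
m.append_measures(logits, labels)
# for class 1: tp=1, fn=1, fp=1 -> precision = recall = f1 = 0.5
p, r, f1 = m.get_total_measure()
print(p, r, f1)
```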