Commit af238596 authored by chenzk's avatar chenzk
Browse files

v1.0

parents
Pipeline #2392 failed with stages
in 0 seconds
# Copyright (c) 2024 westlake-repl
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliate
# SPDX-License-Identifier: MIT
# This file has been modified by Junyi Chen.
#
# Original file was released under MIT, with the full license text
# available at https://choosealicense.com/licenses/mit/.
#
# This modified file is released under the same license.
from torch.utils.data import Dataset
import torch
import pandas as pd
from transformers import AutoTokenizer
import logging
class BatchTextDataset(Dataset):
    """Per-item text dataset: yields tokenized text features for one item id.

    Each sample holds flattened token ids, a per-item token count, and
    right-aligned position ids in the remove-padding ("rmpad") layout, with
    ``item_emb_token_n`` placeholder slots appended for learned item-embedding
    tokens.
    """

    def __init__(self, config, dataload):
        # Total item count; also used as the dataset length.
        self.item_num = dataload.item_num
        # Maps internal item index -> raw item token (id2token table).
        self.item_list = dataload.id2token['item_id']
        self.max_text_length = config['MAX_TEXT_LENGTH']
        self.device = config['device']
        # CSV file with one row of text fields per item.
        self.text_path = config['text_path']
        self.text_keys = config['text_keys']
        self.tokenizer = AutoTokenizer.from_pretrained(config['item_pretrain_dir'], trust_remote_code=True)
        # self.pad_id = self.tokenizer.pad_token_id
        # assert self.pad_id is not None, f"pad_token_id can't be {self.pad_id}"
        # Prompt prefix prepended to every item's text before tokenization.
        self.item_prompt = config['item_prompt']
        # Number of placeholder slots reserved per item for item-embedding tokens.
        self.item_emb_token_n = config['item_emb_token_n']
        self.logger = logging.getLogger()
        self.load_content()

    def __len__(self):
        return self.item_num

    def load_content(self):
        """Load the item-text CSV into a dict: item_id -> {text_key: value}."""
        self.env = pd.read_csv(self.text_path, delimiter=',', dtype={'item_id': str})
        self.env = self.env[self.text_keys + ['item_id']]
        # Transpose so each item id maps to a {column: value} dict.
        self.env = self.env.set_index('item_id').T.to_dict()
        self.logger.info(f"Text Item num: {len(self.env)}")

    def __getitem__(self, index):
        def process_item(item):
            # Warn when a real item (not the padding token at index 0) has no text.
            if item != self.item_list[0] and item not in self.env:
                self.logger.info(f"{item} not in self.env")
            item_i = self.env.get(item, {})
            text_str = ""
            if len(item_i):
                text_str = f"{self.item_prompt}"
                for key in self.text_keys:
                    value = item_i[key]
                    # Skip empty fields and NaN values read from the CSV.
                    if value and str(value) != 'nan':
                        text_str += f"{key}: {value}"
            ids = self.tokenizer.encode(text_str)
            ids = ids[:self.max_text_length]
            mask = [1] * len(ids)
            return ids, mask

        # Index 0 is the padding item and maps to empty text.
        # NOTE(review): index == self.item_num looks unreachable via a DataLoader
        # since __len__ is item_num — confirm the intent of this extra guard.
        if index == 0 or index == self.item_num:
            item_token_i = ""
        else:
            item_token_i = self.item_list[index]
        pos_input_ids, pos_cu_input_lens, pos_position_ids = [], [], []
        ids, _ = process_item(item_token_i)
        # Placeholder slots for the item-embedding tokens follow the text tokens.
        pos_input_ids.extend(ids + [0] * self.item_emb_token_n)
        pos_cu_input_lens.append(len(ids) + self.item_emb_token_n)
        # Positions are right-aligned to max_text_length.
        pos_position_ids.extend((torch.arange(len(ids) + self.item_emb_token_n) + (self.max_text_length - len(ids))).tolist())
        outputs = {
            "pos_item_ids": torch.as_tensor(index, dtype=torch.int64),
            "pos_input_ids": torch.as_tensor(pos_input_ids, dtype=torch.int64),
            "pos_cu_input_lens": torch.as_tensor(pos_cu_input_lens, dtype=torch.int64),
            "pos_position_ids": torch.as_tensor(pos_position_ids, dtype=torch.int64)
        }
        return outputs
# Copyright (c) 2024 westlake-repl
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliate
# SPDX-License-Identifier: MIT
# This file has been modified by Junyi Chen.
#
# Original file was released under MIT, with the full license text
# available at https://choosealicense.com/licenses/mit/.
#
# This modified file is released under the same license.
import torch
import numpy as np
from torch.utils.data._utils.collate import default_collate
import re
try:
from torch._six import string_classes
except:
string_classes = str
import collections
np_str_obj_array_pattern = re.compile(r"[SaUO]")
default_collate_err_msg_format = (
"default_collate: batch must contain tensors, numpy arrays, numbers, "
"dicts or lists; found {}"
)
def customize_collate(batch):
    r"""Collate a batch into tensors, mirroring ``default_collate``.

    Differences from ``torch.utils.data._utils.collate.default_collate``:
    strings and generic sequences are returned untouched (no recursive
    stacking), while mappings are collated per-key recursively.

    Args:
        batch (list): samples to merge; all elements share one type.

    Returns:
        The collated batch (tensor, dict, or the unchanged list for
        strings/sequences).

    Raises:
        TypeError: if the element type is not supported.
    """
    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # In a worker process, stack directly into shared memory to
            # avoid an extra copy when tensors cross process boundaries.
            numel = sum(x.numel() for x in batch)
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        return torch.stack(batch, 0, out=out)
    elif (
        elem_type.__module__ == "numpy"
        and elem_type.__name__ != "str_"
        and elem_type.__name__ != "string_"
    ):
        if elem_type.__name__ == "ndarray" or elem_type.__name__ == "memmap":
            # Arrays of strings/objects cannot be converted to tensors.
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(elem.dtype))
            return default_collate([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():  # numpy scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int):
        return torch.tensor(batch)
    elif isinstance(elem, string_classes):
        return batch
    elif isinstance(elem, collections.abc.Mapping):
        return {key: customize_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, collections.abc.Sequence):
        # Sequences are deliberately passed through unchanged (variable lengths).
        return batch
    # BUGFIX: previously fell off the end and silently returned None for
    # unsupported types; fail loudly like default_collate instead.
    raise TypeError(default_collate_err_msg_format.format(elem_type))
def seq_eval_collate(batch):
    """Collate evaluation samples into batched tensors.

    Each sample is (history tensor, item_seq list, target id, time features).
    Histories have variable length, so they are flattened together with a
    parallel tensor of row (user) indices instead of being stacked.
    """
    histories, seqs, targets, times = zip(*batch)
    # Row index of the owning sample, repeated once per history entry.
    history_u = torch.cat(
        [torch.full_like(hist, row) for row, hist in enumerate(histories)]
    )
    history_i = torch.cat(list(histories))
    item_seq = torch.tensor(seqs)        # [batch, len]
    item_target = torch.tensor(targets)  # [batch]
    time_seq = torch.tensor(times)       # [batch]
    positive_u = torch.arange(item_seq.shape[0])  # [batch]
    return item_seq, time_seq, (history_u, history_i), positive_u, item_target
def customize_rmpad_collate(batch):
r"""Puts each data field into a tensor with outer dimension batch size"""
elem = batch[0]
elem_type = type(elem)
if isinstance(elem, torch.Tensor):
out = None
if torch.utils.data.get_worker_info() is not None:
# If we're in a background process, concatenate directly into a
# shared memory tensor to avoid an extra copy
numel = sum(x.numel() for x in batch)
storage = elem.storage()._new_shared(numel)
out = elem.new(storage)
return torch.stack(batch, 0, out=out)
elif (
elem_type.__module__ == "numpy"
and elem_type.__name__ != "str_"
and elem_type.__name__ != "string_"
):
if elem_type.__name__ == "ndarray" or elem_type.__name__ == "memmap":
# array of string classes and object
if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
raise TypeError(default_collate_err_msg_format.format(elem.dtype))
return default_collate([torch.as_tensor(b) for b in batch])
elif elem.shape == (): # scalars
return torch.as_tensor(batch)
elif isinstance(elem, float):
return torch.tensor(batch, dtype=torch.float64)
elif isinstance(elem, int):
return torch.tensor(batch)
elif isinstance(elem, string_classes):
return batch
elif isinstance(elem, collections.abc.Mapping):
output = {}
for key in elem:
if any(['_input_ids' in key, '_cu_input_lens' in key, '_position_ids' in key]):
output[key] = torch.concat([d[key] for d in batch], dim=0)
else:
output[key] = customize_collate([d[key] for d in batch])
return output
elif isinstance(elem, collections.abc.Sequence):
# check to make sure that the elements in batch have consistent size
return batch
# Copyright (c) 2024 westlake-repl
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliate
# SPDX-License-Identifier: MIT
# This file has been modified by Junyi Chen.
#
# Original file was released under MIT, with the full license text
# available at https://choosealicense.com/licenses/mit/.
#
# This modified file is released under the same license.
import torch
from torch.utils.data import Dataset
import numpy as np
import datetime
import pytz
class SeqEvalDataset(Dataset):
    """Evaluation dataset for sequential recommendation.

    For ``phase='valid'`` the target is the second-to-last item of each user
    sequence (the last item is reserved for test); for ``phase='test'`` it is
    the last one. The history preceding the target is left-padded with 0 to a
    fixed length, and timestamps are expanded into calendar components.
    """

    def __init__(self, config, dataload, phase='valid'):
        self.dataload = dataload
        # Prefer the eval-time length cap when configured, else the train cap.
        self.max_item_list_length = (
            config['MAX_ITEM_LIST_LENGTH_TEST']
            if config['MAX_ITEM_LIST_LENGTH_TEST']
            else config['MAX_ITEM_LIST_LENGTH']
        )
        self.user_seq = list(dataload.user_seq.values())
        self.time_seq = list(dataload.time_seq.values())
        self.use_time = config['use_time']
        self.phase = phase
        self.length = len(self.user_seq)
        self.item_num = dataload.item_num

    def __len__(self):
        return self.length

    def _padding_sequence(self, sequence, max_length):
        # Left-pad with zeros, then keep only the most recent entries.
        padded = [0] * (max_length - len(sequence)) + list(sequence)
        return padded[-max_length:]

    def _padding_time_sequence(self, sequence, max_length):
        # Same left-padding as item ids, then expand each timestamp into
        # (year, month, day, hour, minute, second) in UTC.
        padded = [0] * (max_length - len(sequence)) + list(sequence)
        padded = padded[-max_length:]
        calendar_parts = []
        for ts in padded:
            moment = datetime.datetime.fromtimestamp(ts, pytz.timezone('UTC'))
            calendar_parts.append([
                moment.year, moment.month, moment.day,
                moment.hour, moment.minute, moment.second,
            ])
        return calendar_parts

    def __getitem__(self, index):
        # 'valid' holds out the final item for test, so its target sits at -2.
        cutoff = 2 if self.phase == 'valid' else 1
        full_seq = self.user_seq[index]
        history_seq = full_seq[:-cutoff]
        item_seq = self._padding_sequence(history_seq, self.max_item_list_length)
        item_target = full_seq[-cutoff]
        history_time_seq = self.time_seq[index][:-cutoff] if self.use_time else []
        time_seq = self._padding_time_sequence(history_time_seq, self.max_item_list_length)
        return torch.tensor(history_seq), item_seq, item_target, time_seq
# Copyright (c) 2024 westlake-repl
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliate
# SPDX-License-Identifier: MIT
# This file has been modified by Junyi Chen.
#
# Original file was released under MIT, with the full license text
# available at https://choosealicense.com/licenses/mit/.
#
# This modified file is released under the same license.
from asyncio.log import logger
from torch.utils.data import Dataset
import torch
import numpy as np
import pandas as pd
from transformers import AutoTokenizer
import random
import datetime
import pytz
import math
import torch.distributed as dist
# Sample layout: [[user_seq], [neg_item_seq]], [mask]
class SEQTrainDataset(Dataset):
    """Training dataset for ID-based sequential recommendation.

    Each sample is a triple of left-padded long tensors: the item sequence,
    the sampled negatives, and a mask marking real target positions.

    Example with MAX_ITEM_LIST_LENGTH = 5 (max_seq_length = 6) and raw
    sequence 1,2,3,4,5,6:
        positives  2,3,4,5,6
        negatives  0,8,9,7,9,8 (sampled)
        mask       1,1,1,1,1
    """

    def __init__(self, config, dataload):
        self.dataload = dataload
        self.config = config
        self.item_num = dataload.item_num
        self.train_seq = dataload.train_feat['item_seq']
        self.length = len(self.train_seq)
        # +1 keeps the target of the final position inside the sequence.
        self.max_seq_length = config['MAX_ITEM_LIST_LENGTH'] + 1
        self.device = config['device']
        # NCE loss pads with random item ids instead of zeros.
        self.random_sample = bool(config['loss'] and config['loss'] == 'nce')
        self.num_negatives = config['num_negatives']
        if self.num_negatives:
            # Split the global negative budget across ranks and batch rows.
            self.num_negatives = math.ceil(
                self.num_negatives / dist.get_world_size() / config['train_batch_size'])
        logger.info(f"Use random sample {self.random_sample} for mask id")

    def __len__(self):
        return self.length

    def _neg_sample(self, item_set):
        # Rejection-sample an item id (>= 1) that is not in item_set.
        while True:
            candidate = random.randint(1, self.item_num - 1)
            if candidate not in item_set:
                return candidate

    def _padding_sequence(self, sequence, max_length, random_sample=False):
        pad_len = max_length - len(sequence)
        if random_sample:
            prefix = [self._neg_sample(sequence) for _ in range(pad_len)]
        else:
            prefix = [0] * pad_len
        return torch.tensor((prefix + sequence)[-max_length:], dtype=torch.long)

    def reconstruct_train_data(self, item_seq):
        """Build (padded sequence, negatives, mask) for one raw user sequence."""
        seq_len = len(item_seq)
        # One sampled negative and one mask flag per predicted position.
        neg_item = [self._neg_sample(item_seq) for _ in range(seq_len - 1)]
        masked_index = [1] * (seq_len - 1)
        item_seq = self._padding_sequence(
            list(item_seq), self.max_seq_length, random_sample=self.random_sample)
        if self.num_negatives:
            # Fixed-size negative set, excluding items in the padded sequence.
            neg_item = [self._neg_sample(item_seq) for _ in range(self.num_negatives)]
        else:
            neg_item = self._padding_sequence(
                neg_item, self.max_seq_length, random_sample=self.random_sample)
        masked_index = self._padding_sequence(masked_index, self.max_seq_length - 1)
        return (torch.as_tensor(item_seq, dtype=torch.int64),
                torch.as_tensor(neg_item, dtype=torch.int64),
                torch.as_tensor(masked_index, dtype=torch.int64))

    def __getitem__(self, index):
        return self.reconstruct_train_data(self.train_seq[index])
class TextSEQTrainDataset(Dataset):
    """Training dataset for text-based sequential recommendation (HLLM path).

    Extends ID-based training samples with tokenized item text: for every
    positive and negative item it emits flattened token ids, per-item token
    counts and right-aligned position ids in the remove-padding layout, plus
    ``item_emb_token_n`` placeholder slots for learned item-embedding tokens.
    """

    def __init__(self, config, dataload):
        self.dataload = dataload
        self.config = config
        self.item_num = dataload.item_num
        self.train_seq = dataload.train_feat['item_seq']
        self.length = len(self.train_seq)
        self.train_time_seq = dataload.train_feat['time_seq']
        # Maps internal item index -> raw item token (id2token table).
        self.id2token = dataload.id2token['item_id']
        # +1 keeps the target of the final position inside the sequence.
        self.max_seq_length = config['MAX_ITEM_LIST_LENGTH']+1
        self.max_text_length = config['MAX_TEXT_LENGTH']
        self.device = config['device']
        # CSV file with one row of text fields per item.
        self.text_path = config['text_path']
        self.text_keys = config['text_keys']
        self.tokenizer = AutoTokenizer.from_pretrained(config['item_pretrain_dir'], trust_remote_code=True)
        # self.pad_id = self.tokenizer.pad_token_id
        # assert self.pad_id is not None, f"pad_token_id can't be {self.pad_id}"
        self.item_prompt = config['item_prompt']
        self.item_emb_token_n = config['item_emb_token_n']
        self.num_negatives = config['num_negatives']
        # NCE loss pads sequences with random item ids instead of zeros.
        self.random_sample = True if config['loss'] and config['loss'] == 'nce' else False
        if self.num_negatives:
            # Split the global negative budget across ranks and batch rows.
            self.num_negatives = math.ceil(self.num_negatives / dist.get_world_size() / config['train_batch_size'])  # for llm only
        logger.info(f"Use random sample {self.random_sample} for mask id")
        logger.info(f"Text path: {self.text_path}")
        logger.info(f"Text keys: {self.text_keys}")
        logger.info(f"Item prompt: {self.item_prompt}")
        self.load_content()

    def __len__(self):
        return self.length

    def load_content(self):
        """Load the item-text CSV into a dict: item_id -> {text_key: value}."""
        self.env = pd.read_csv(self.text_path, delimiter=',', dtype={'item_id': str})
        self.env = self.env[self.text_keys + ['item_id']]
        # Transpose so each item id maps to a {column: value} dict.
        self.env = self.env.set_index('item_id').T.to_dict()
        logger.info(f"Text Item num: {len(self.env)}")

    def _neg_sample(self, item_set):
        # Rejection-sample an item id (>= 1) not present in item_set.
        item = random.randint(1, self.item_num - 1)
        while item in item_set:
            item = random.randint(1, self.item_num - 1)
        return item

    def _padding_sequence(self, sequence, max_length, random_sample=False):
        # Left-pad to max_length (random items for NCE, else zeros), then keep
        # the most recent max_length entries.
        pad_len = max_length - len(sequence)
        if random_sample:
            pad_seq = [self._neg_sample(sequence) for _ in range(pad_len)]
            sequence = pad_seq + sequence
        else:
            sequence = [0] * pad_len + sequence
        sequence = sequence[-max_length:]
        return torch.tensor(sequence, dtype=torch.long)

    def reconstruct_train_data(self, item_seq):
        """Build (padded item seq, negatives, target mask) for one user."""
        masked_index = []
        neg_item = []
        item_seq_len = len(item_seq)
        # One sampled negative and one mask flag per predicted position.
        for i in range(item_seq_len - 1):
            neg_item.append(self._neg_sample(item_seq))
            masked_index.append(1)
        item_seq = self._padding_sequence(list(item_seq), self.max_seq_length, random_sample=self.random_sample)
        masked_index = self._padding_sequence(masked_index, self.max_seq_length-1)
        if self.num_negatives:
            # Fixed-size negative set sampled with no exclusion list.
            neg_item = []
            for _ in range(self.num_negatives):
                neg_item.append(self._neg_sample([]))
        else:
            neg_item = self._padding_sequence(neg_item, self.max_seq_length, random_sample=self.random_sample)
        return item_seq, neg_item, masked_index

    def _padding_time_sequence(self, sequence, max_length):
        # Left-pad timestamps with 0, then expand each into
        # (year, month, day, hour, minute, second) in UTC.
        pad_len = max_length - len(sequence)
        sequence = [0] * pad_len + sequence
        sequence = sequence[-max_length:]
        vq_time = []
        for time in sequence:
            dt = datetime.datetime.fromtimestamp(time, pytz.timezone('UTC'))
            vq_time.append([dt.year, dt.month, dt.day, dt.hour, dt.minute, dt.second])
        return torch.tensor(vq_time, dtype=torch.long)

    def __getitem__(self, index):
        item_seq = self.train_seq[index]
        item_seq, neg_item, masked_index = self.reconstruct_train_data(item_seq)
        time_seq = self.train_time_seq[index]
        time_seq = self._padding_time_sequence(list(time_seq), self.max_seq_length)
        # Convert internal ids to raw tokens, which key into the text table.
        item_seq_token = self.id2token[item_seq]
        neg_items_token = self.id2token[neg_item]
        pos_input_ids, pos_cu_input_lens, pos_position_ids = [], [], []
        neg_input_ids, neg_cu_input_lens, neg_position_ids = [], [], []

        def process_item(item):
            # Warn when a real item (not the padding token) has no text.
            if item != self.id2token[0] and item not in self.env:
                # assert item in self.env, f"{item}"
                logger.info(f"{item} not in self.env")
            item_i = self.env.get(item, {})
            text_str = ""
            if len(item_i):
                text_str = f"{self.item_prompt}"
                for key in self.text_keys:
                    value = item_i[key]
                    # Skip empty fields and NaN values read from the CSV.
                    if value and str(value) != 'nan':
                        text_str += f"{key}: {value}"
            ids = self.tokenizer.encode(text_str)
            ids = ids[:self.max_text_length]
            mask = [1] * len(ids)
            return ids, mask

        for item in item_seq_token:
            ids, _ = process_item(item)
            # Placeholder slots for the item-embedding tokens follow the text.
            pos_input_ids.extend(ids + [0] * self.item_emb_token_n)
            pos_cu_input_lens.append(len(ids) + self.item_emb_token_n)
            # Positions are right-aligned to max_text_length.
            pos_position_ids.extend((torch.arange(len(ids) + self.item_emb_token_n) + (self.max_text_length - len(ids))).tolist())
        for neg in neg_items_token:
            ids, _ = process_item(neg)
            neg_input_ids.extend(ids + [0] * self.item_emb_token_n)
            neg_cu_input_lens.append(len(ids) + self.item_emb_token_n)
            neg_position_ids.extend((torch.arange(len(ids) + self.item_emb_token_n) + (self.max_text_length - len(ids))).tolist())
        outputs = {
            "pos_item_ids": torch.as_tensor(item_seq, dtype=torch.int64),
            "neg_item_ids": torch.as_tensor(neg_item, dtype=torch.int64),
            "pos_input_ids": torch.as_tensor(pos_input_ids, dtype=torch.int64),
            "pos_cu_input_lens": torch.as_tensor(pos_cu_input_lens, dtype=torch.int64),
            "pos_position_ids": torch.as_tensor(pos_position_ids, dtype=torch.int64),
            "neg_input_ids": torch.as_tensor(neg_input_ids, dtype=torch.int64),
            "neg_cu_input_lens": torch.as_tensor(neg_cu_input_lens, dtype=torch.int64),
            "neg_position_ids": torch.as_tensor(neg_position_ids, dtype=torch.int64),
            "attention_mask": torch.as_tensor(masked_index, dtype=torch.int64),
            "time_ids": torch.as_tensor(time_seq, dtype=torch.int64),
        }
        return outputs
# Copyright (c) 2024 westlake-repl
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliate
# SPDX-License-Identifier: MIT
# This file has been modified by Junyi Chen.
#
# Original file was released under MIT, with the full license text
# available at https://choosealicense.com/licenses/mit/.
#
# This modified file is released under the same license.
import copy
import importlib
import os
import pickle
from logging import getLogger
from REC.data.dataset import *
from REC.utils import set_color
from functools import partial
from .dataload import Data
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import random
import math
import copy
def load_data(config):
    """Create the Data container that loads and holds the raw dataset."""
    return Data(config)
def bulid_dataloader(config, dataload):
    '''
    Split the dataset, generate user history sequences, and build the
    train/valid/test DataLoaders for the configured model.

    NOTE(review): "bulid" is a typo for "build"; kept to avoid breaking callers.
    '''
    # Per-model (train dataset spec, eval dataset, eval collate fn). A tuple
    # train spec additionally carries a custom train collate fn (HLLM).
    dataset_dict = {
        'SASRec': ('SEQTrainDataset', 'SeqEvalDataset', 'seq_eval_collate'),
        'HSTU': ('SEQTrainDataset', 'SeqEvalDataset', 'seq_eval_collate'),
        'LLMIDRec': ('SEQTrainDataset', 'SeqEvalDataset', 'seq_eval_collate'),
        'HLLM': (('TextSEQTrainDataset', 'customize_rmpad_collate'), 'SeqEvalDataset', 'seq_eval_collate')
    }
    model_name = config['model']
    dataload.build()
    dataset_module = importlib.import_module('REC.data.dataset')
    train_set_name, test_set_name, collate_fn_name = dataset_dict[model_name]
    if isinstance(train_set_name, tuple):
        # (dataset class name, collate fn name) pair.
        train_set_class = getattr(dataset_module, train_set_name[0])
        train_collate_fn = getattr(dataset_module, train_set_name[1])
    else:
        train_set_class = getattr(dataset_module, train_set_name)
        train_collate_fn = None
    test_set_class = getattr(dataset_module, test_set_name)
    eval_collate_fn = getattr(dataset_module, collate_fn_name)
    train_data = train_set_class(config, dataload)
    valid_data = test_set_class(config, dataload, phase='valid')
    test_data = test_set_class(config, dataload, phase='test')
    logger = getLogger()
    logger.info(
        set_color('[Training]: ', 'pink') + set_color('train_batch_size', 'cyan') + ' = ' +
        set_color(f'[{config["train_batch_size"]}]', 'yellow')
    )
    logger.info(
        set_color('[Evaluation]: ', 'pink') + set_color('eval_batch_size', 'cyan') + ' = ' +
        set_color(f'[{config["eval_batch_size"]}]', 'yellow')
    )
    # Training shuffles across ranks; eval uses deterministic strided splits.
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_data)
    valid_sampler = NonConsecutiveSequentialDistributedSampler(valid_data)
    test_sampler = NonConsecutiveSequentialDistributedSampler(test_data)
    num_workers = 8
    rank = torch.distributed.get_rank()
    seed = torch.initial_seed()
    # Give every worker on every rank a distinct, reproducible RNG seed.
    init_fn = partial(
        worker_init_fn, num_workers=num_workers, rank=rank,
        seed=seed
    )
    if train_collate_fn:
        train_loader = DataLoader(train_data, batch_size=config['train_batch_size'], num_workers=num_workers,
                                  pin_memory=True, sampler=train_sampler, collate_fn=train_collate_fn, worker_init_fn=init_fn)
    else:
        train_loader = DataLoader(train_data, batch_size=config['train_batch_size'], num_workers=num_workers,
                                  pin_memory=True, sampler=train_sampler, worker_init_fn=init_fn)
    valid_loader = DataLoader(valid_data, batch_size=config['eval_batch_size'], num_workers=num_workers,
                              pin_memory=True, sampler=valid_sampler, collate_fn=eval_collate_fn)
    test_loader = DataLoader(test_data, batch_size=config['eval_batch_size'], num_workers=num_workers,
                             pin_memory=True, sampler=test_sampler, collate_fn=eval_collate_fn)
    return train_loader, valid_loader, test_loader
def worker_init_fn(worker_id, num_workers, rank, seed):
    """Give every dataloader worker a distinct, reproducible RNG seed.

    The per-worker seed is ``num_workers * rank + worker_id + seed`` so no
    two (rank, worker) pairs share a random stream.
    """
    derived_seed = num_workers * rank + worker_id + seed
    random.seed(derived_seed)
    np.random.seed(derived_seed)
def worker_init_reset_seed(worker_id):
    """Re-seed python/numpy RNGs from torch's initial seed, worker id and rank."""
    base_seed = torch.initial_seed() % 2 ** 31
    worker_seed = base_seed + worker_id + torch.distributed.get_rank()
    np.random.seed(worker_seed)
    random.seed(worker_seed)
class NonConsecutiveSequentialDistributedSampler(torch.utils.data.sampler.Sampler):
    """Deterministic eval sampler: rank r takes indices r, r+W, r+2W, ...

    Unlike DistributedSampler, no padding samples are added, so ranks may
    receive slightly different sample counts.
    """

    def __init__(self, dataset, rank=None, num_replicas=None):
        if num_replicas is None:
            if not torch.distributed.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = torch.distributed.get_world_size()
        if rank is None:
            if not torch.distributed.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = torch.distributed.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.total_size = len(self.dataset)
        # Size of range(rank, total_size, num_replicas).
        self.num_samples = math.ceil((self.total_size - self.rank) / self.num_replicas)

    def __iter__(self):
        return iter(range(self.rank, self.total_size, self.num_replicas))

    def __len__(self):
        return self.num_samples
class ConsecutiveSequentialDistributedSampler(torch.utils.data.sampler.Sampler):
    """Eval sampler giving each rank one contiguous, equally-sized slice.

    The index list is padded by repeating the last index until it divides
    evenly by batch_size * num_replicas; callers must drop the duplicates.
    """

    def __init__(self, dataset, batch_size, rank=None, num_replicas=None):
        if num_replicas is None:
            if not torch.distributed.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = torch.distributed.get_world_size()
        if rank is None:
            if not torch.distributed.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = torch.distributed.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.batch_size = batch_size
        # Round each rank's share up to a whole number of batches.
        per_rank_batches = int(math.ceil(len(self.dataset) * 1.0 / self.batch_size / self.num_replicas))
        self.num_samples = per_rank_batches * self.batch_size
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self):
        indices = list(range(len(self.dataset)))
        # Repeat the final index so every rank gets exactly num_samples items.
        indices.extend([indices[-1]] * (self.total_size - len(indices)))
        start = self.rank * self.num_samples
        return iter(indices[start:start + self.num_samples])

    def __len__(self):
        return self.num_samples
from .base_metric import *
from .metrics import *
from .evaluator import *
from .register import *
from .collector import *
# Copyright (c) 2024 westlake-repl
# SPDX-License-Identifier: MIT
import torch
from REC.utils import EvaluatorType
class AbstractMetric(object):
    """Base object of all metrics; subclass this to implement a metric.

    Args:
        config (Config): the config of the evaluator.
    """

    # Set True for metrics where a smaller value is better.
    smaller = False

    def __init__(self, config):
        # Rounding precision: configured places + 2, defaulting to 7 when unset.
        places = config['metric_decimal_place']
        self.decimal_place = places + 2 if places else 7

    def calculate_metric(self, dataobject):
        """Return this metric as a dict, e.g. ``{'metric@10': 0.3153}``.

        Args:
            dataobject (DataStruct): all information needed to compute the metric.

        Raises:
            NotImplementedError: subclasses must override.
        """
        raise NotImplementedError('Method [calculate_metric] should be implemented.')
class TopkMetric(AbstractMetric):
    """Base object of top-k ranking metrics; subclass to implement one.

    Args:
        config (Config): the config of the evaluator.
    """

    metric_type = EvaluatorType.RANKING
    metric_need = ['rec.topk']

    def __init__(self, config):
        super().__init__(config)
        self.topk = config['topk']

    def used_info(self, dataobject):
        """Split 'rec.topk' into a bool hit matrix and per-user positive counts."""
        rec_mat = dataobject.get('rec.topk')
        topk_idx, pos_len_list = torch.split(rec_mat, [max(self.topk), 1], dim=1)
        return topk_idx.to(torch.bool).numpy(), pos_len_list.squeeze(-1).numpy()

    def topk_result(self, metric, value):
        """Map per-k metric values into a ``{'metric@k': value}`` dict.

        Args:
            metric (str): name of the calculated metric.
            value (numpy.ndarray): per-user metrics from metric@1 up to
                metric@max(self.topk).

        Returns:
            dict: the metric values requested by the configuration.
        """
        summed = value.sum(axis=0)
        return {'{}@{}'.format(metric, k): summed[k - 1] for k in self.topk}

    def metric_info(self, pos_index, pos_len=None):
        """Compute per-user metric values; subclasses must override.

        Args:
            pos_index (numpy.ndarray): bool matrix of shape n_users x max(topk);
                pos_index[i][j] is True iff user i's (j+1)-th ranked item is positive.
            pos_len (numpy.ndarray): number of positive items per user, shape (n_users,).

        Returns:
            numpy.ndarray: per-user values from metric@1 to metric@max(self.topk).
        """
        raise NotImplementedError('Method [metric_info] of top-k metric should be implemented.')
class LossMetric(AbstractMetric):
    """Base object of loss/value metrics (e.g. AUC); subclass to implement one.

    Args:
        config (Config): the config of the evaluator.
    """

    metric_type = EvaluatorType.VALUE
    metric_need = ['rec.score', 'data.label']

    def __init__(self, config):
        super().__init__(config)

    def used_info(self, dataobject):
        """Return (model scores, ground-truth labels) as 1-D numpy arrays."""
        preds = dataobject.get('rec.score')
        trues = dataobject.get('data.label')
        return preds.squeeze(-1).numpy(), trues.squeeze(-1).numpy()

    def output_metric(self, metric, dataobject):
        """Compute the metric and wrap it as ``{metric: rounded value}``."""
        preds, trues = self.used_info(dataobject)
        return {metric: round(self.metric_info(preds, trues), self.decimal_place)}

    def metric_info(self, preds, trues):
        """Compute the scalar metric value; subclasses must override.

        Args:
            preds (numpy.ndarray): scores predicted by the model, 1-D.
            trues (numpy.ndarray): item labels, same shape as ``preds``.

        Returns:
            float: the value of the metric.
        """
        raise NotImplementedError('Method [metric_info] of loss-based metric should be implemented.')
# Copyright (c) 2024 westlake-repl
# SPDX-License-Identifier: MIT
from .register import Register
import torch
import copy
import numpy as np
class DataStruct(object):
    """A small dict-like container for evaluation resources.

    Supports item access, membership tests, and in-place tensor accumulation
    via :meth:`update_tensor`.
    """

    def __init__(self):
        self._data_dict = {}

    def __getitem__(self, name: str):
        return self._data_dict[name]

    def __setitem__(self, name: str, value):
        self._data_dict[name] = value

    def __delitem__(self, name: str):
        self._data_dict.pop(name)

    def __contains__(self, key: str):
        return key in self._data_dict

    def get(self, name: str):
        """Return a registered resource; raise IndexError when absent."""
        if name not in self._data_dict:
            raise IndexError("Can not load the data without registration !")
        return self[name]

    def set(self, name: str, value):
        self._data_dict[name] = value

    def update_tensor(self, name: str, value: torch.Tensor):
        """Append a detached CPU copy of ``value`` to the stored tensor."""
        detached = value.cpu().clone().detach()
        if name not in self._data_dict:
            self._data_dict[name] = detached
            return
        stored = self._data_dict[name]
        if not isinstance(stored, torch.Tensor):
            raise ValueError("{} is not a tensor.".format(name))
        self._data_dict[name] = torch.cat((stored, detached), dim=0)

    def __str__(self):
        return '\nContaining:\n' + ''.join(key + '\n' for key in self._data_dict)
class Collector(object):
    """Collects the resources needed by the evaluator.

    As the evaluation metrics are various, the needed resources contain not
    only the recommended results but also data from the dataloader and the
    model; everything is accumulated here during training and evaluation.
    This class is only used in Trainer.
    """

    def __init__(self, config):
        self.config = config
        self.data_struct = DataStruct()
        self.register = Register(config)
        self.full = True
        self.topk = self.config['topk']
        self.device = self.config['device']

    def data_collect(self, train_data):
        """Collect evaluation resources from the training dataloader.

        Args:
            train_data (AbstractDataLoader): dataloader holding training data.
        """
        if self.register.need('data.num_items'):
            self.data_struct.set('data.num_items', train_data.dataset.item_num)
        if self.register.need('data.num_users'):
            self.data_struct.set('data.num_users', train_data.dataset.user_num)
        if self.register.need('data.count_items'):
            self.data_struct.set('data.count_items', train_data.dataset.item_counter)
        if self.register.need('data.count_users'):
            # BUGFIX: previously stored under 'data.count_items', which
            # clobbered the item counter collected just above.
            self.data_struct.set('data.count_users', train_data.dataset.user_counter)

    def _average_rank(self, scores):
        """Get the ranking of an ordered tensor, averaging tied positions.

        Args:
            scores (tensor): an ordered tensor, with size of `(N, )`

        Returns:
            torch.Tensor: average_rank

        Example:
            >>> average_rank(tensor([[1,2,2,2,3,3,6],[2,2,2,2,4,5,5]]))
            tensor([[1.0000, 3.0000, 3.0000, 3.0000, 5.5000, 5.5000, 7.0000],
            [2.5000, 2.5000, 2.5000, 2.5000, 5.0000, 6.5000, 6.5000]])

        Reference:
            https://github.com/scipy/scipy/blob/v0.17.1/scipy/stats/stats.py#L5262-L5352
        """
        length, width = scores.shape
        true_tensor = torch.full((length, 1), True, dtype=torch.bool, device=self.device)
        # True wherever a new distinct value starts within each row.
        obs = torch.cat([true_tensor, scores[:, 1:] != scores[:, :-1]], dim=1)
        # Bias added to dense so that cumsum ranks stay per-row.
        bias = torch.arange(0, length, device=self.device).repeat(width).reshape(width, -1). \
            transpose(1, 0).reshape(-1)
        dense = obs.view(-1).cumsum(0) + bias
        # Cumulative counts of each unique value.
        count = torch.where(torch.cat([obs, true_tensor], dim=1))[1]
        # Average rank of each tied group = midpoint of its index range.
        avg_rank = .5 * (count[dense] + count[dense - 1] + 1).view(length, -1)
        return avg_rank

    def eval_batch_collect(
        self, scores_tensor: torch.Tensor, positive_u: torch.Tensor, positive_i: torch.Tensor, interaction=None
    ):
        """Collect evaluation resources from one batch of eval output.

        Args:
            scores_tensor (torch.Tensor): model output scores of shape `(N, )`.
            positive_u (torch.Tensor): row index of positive items per user.
            positive_i (torch.Tensor): positive item id per user.
            interaction: batched eval data (currently unused).
        """
        if self.register.need('rec.items'):
            # Keep only the top-k recommended item indices per user.
            _, topk_idx = torch.topk(scores_tensor, max(self.topk), dim=-1)  # n_users x k
            self.data_struct.update_tensor('rec.items', topk_idx)
        if self.register.need('rec.topk'):
            _, topk_idx = torch.topk(scores_tensor, max(self.topk), dim=-1)  # n_users x k
            pos_matrix = torch.zeros_like(scores_tensor, dtype=torch.int)
            pos_matrix[positive_u, positive_i] = 1
            pos_len_list = pos_matrix.sum(dim=1, keepdim=True)
            # 1 where a top-k slot holds a positive item, plus the positive count.
            pos_idx = torch.gather(pos_matrix, dim=1, index=topk_idx)
            result = torch.cat((pos_idx, pos_len_list), dim=1)
            self.data_struct.update_tensor('rec.topk', result)
        if self.register.need('rec.meanrank'):
            desc_scores, desc_index = torch.sort(scores_tensor, dim=-1, descending=True)
            # Locate positive items inside the full descending ranking.
            pos_matrix = torch.zeros_like(scores_tensor)
            pos_matrix[positive_u, positive_i] = 1
            pos_index = torch.gather(pos_matrix, dim=1, index=desc_index)
            avg_rank = self._average_rank(desc_scores)
            pos_rank_sum = torch.where(pos_index == 1, avg_rank, torch.zeros_like(avg_rank)).sum(dim=-1, keepdim=True)
            pos_len_list = pos_matrix.sum(dim=1, keepdim=True)
            user_len_list = desc_scores.argmin(dim=1, keepdim=True)
            result = torch.cat((pos_rank_sum, user_len_list, pos_len_list), dim=1)
            self.data_struct.update_tensor('rec.meanrank', result)
        if self.register.need('rec.score'):
            self.data_struct.update_tensor('rec.score', scores_tensor)
        # NOTE(review): 'data.label' collection from `interaction` is not
        # implemented here — eval_collect handles labels instead.

    def model_collect(self, model: torch.nn.Module):
        """Collect evaluation resources from the model.

        Args:
            model (nn.Module): the trained recommendation model.
        """
        pass
        # TODO:

    def eval_collect(self, eval_pred: torch.Tensor, data_label: torch.Tensor):
        """Collect evaluation resources from total output and labels.

        Designed for models that cannot predict per-batch.

        Args:
            eval_pred (torch.Tensor): the model's output score tensor.
            data_label (torch.Tensor): the label tensor.
        """
        if self.register.need('rec.score'):
            self.data_struct.update_tensor('rec.score', eval_pred)
        if self.register.need('data.label'):
            self.label_field = self.config['LABEL_FIELD']
            self.data_struct.update_tensor('data.label', data_label.to(self.device))

    def distributed_concat(self, tensor, num_total_examples):
        """All-gather `tensor` across ranks and truncate sampler padding."""
        output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())]
        torch.distributed.all_gather(output_tensors, tensor)
        concat = torch.cat(output_tensors, dim=0)
        # Truncate the dummy elements added by SequentialDistributedSampler.
        return concat[:num_total_examples]

    def get_data_struct(self):
        """Return a deep copy of everything collected and reset stale entries."""
        returned_struct = copy.deepcopy(self.data_struct)
        for key in ['rec.topk', 'rec.meanrank', 'rec.score', 'rec.items', 'data.label']:
            if key in self.data_struct:
                del self.data_struct[key]
        return returned_struct
# Copyright (c) 2024 westlake-repl
# SPDX-License-Identifier: MIT
from .register import metrics_dict
from .collector import DataStruct
from collections import OrderedDict
class Evaluator(object):
    """Checks metric configuration and summarizes results of all metrics."""

    def __init__(self, config):
        self.config = config
        self.metrics = [name.lower() for name in self.config['metrics']]
        # Instantiate one metric object per configured metric name.
        self.metric_class = {name: metrics_dict[name](self.config) for name in self.metrics}

    def evaluate(self, dataobject: DataStruct):
        """Calculate all metrics; called at the end of each epoch.

        Args:
            dataobject (DataStruct): all information needed for the metrics.

        Returns:
            collections.OrderedDict: e.g. ``{'hit@20': 0.3824, 'recall@20': 0.0527}``.
        """
        result_dict = OrderedDict()
        for name in self.metrics:
            result_dict.update(self.metric_class[name].calculate_metric(dataobject))
        return result_dict
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment