Commit af238596 authored by chenzk

v1.0

# Copyright (c) 2024 westlake-repl
# SPDX-License-Identifier: MIT
import inspect
import sys
def cluster_info(module_name):
"""Collect information of all metrics, including:
- ``metric_need``: Information needed to calculate this metric, the combination of ``rec.items, rec.topk,
rec.meanrank, rec.score, data.num_items, data.num_users, data.count_items, data.count_users, data.label``.
- ``metric_type``: Whether the scores required by metric are grouped by user, range in ``EvaluatorType.RANKING``
and ``EvaluatorType.VALUE``.
- ``smaller``: Whether the smaller metric value represents better performance,
range in ``True`` and ``False``, default to ``False``.
Note:
For ``metric_type``: in current RecBole, all the "grouped-score" metrics are ranking-based and all the
"non-grouped-score" metrics are value-based. To keep with our paper, we adopted the more formal terms:
``RANKING`` and ``VALUE``.
Args:
module_name (str): the name of the metrics module, e.g. ``REC.evaluator.metrics``.
Returns:
tuple: A list of metrics whose smaller values indicate better performance, the
``metric_need`` and ``metric_type`` dictionaries described above, and a dictionary
mapping metric names to metric classes.
"""
smaller_m = []
m_dict, m_info, m_types = {}, {}, {}
metric_class = inspect.getmembers(
sys.modules[module_name], lambda x: inspect.isclass(x) and x.__module__ == module_name
)
for name, metric_cls in metric_class:
name = name.lower()
m_dict[name] = metric_cls
if hasattr(metric_cls, 'metric_need'):
m_info[name] = metric_cls.metric_need
else:
raise AttributeError(f"Metric '{name}' has no attribute [metric_need].")
if hasattr(metric_cls, 'metric_type'):
m_types[name] = metric_cls.metric_type
else:
raise AttributeError(f"Metric '{name}' has no attribute [metric_type].")
if getattr(metric_cls, 'smaller', False):
smaller_m.append(name)
return smaller_m, m_info, m_types, m_dict
metric_module_name = 'REC.evaluator.metrics'
smaller_metrics, metric_information, metric_types, metrics_dict = cluster_info(metric_module_name)
class Register(object):
""" Register module load the registry according to the metrics in config.
It is a member of DataCollector.
The DataCollector collect the resource that need for Evaluator under the guidance of Register
"""
def __init__(self, config):
self.config = config
self.metrics = [metric.lower() for metric in self.config['metrics']]
self._build_register()
def _build_register(self):
for metric in self.metrics:
metric_needs = metric_information[metric]
for info in metric_needs:
setattr(self, info, True)
def has_metric(self, metric: str):
return metric.lower() in self.metrics
def need(self, key: str):
if hasattr(self, key):
return getattr(self, key)
return False
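# Usage sketch (hedged): the config below is a minimal stand-in for the real
# config object, and it assumes REC.evaluator.metrics defines metrics named
# Recall and NDCG so that the module-level cluster_info call above succeeds.
#
#   config = {'metrics': ['Recall', 'NDCG']}
#   register = Register(config)
#   register.has_metric('recall')   # -> True
#   register.need('rec.topk')       # -> True if either metric lists 'rec.topk' in its metric_need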
# Copyright (c) 2024 westlake-repl
# SPDX-License-Identifier: MIT
import itertools
import numpy as np
import torch
def pad_sequence(sequences, len_list, pad_to=None, padding_value=-np.inf):
"""pad sequences to a matrix
Args:
sequences (list): list of variable length sequences.
len_list (list): the lengths of the tensors in ``sequences``.
pad_to (int, optional): if pad_to is not None, the sequences will pad to the length you set,
else the sequence will pad to the max length of the sequences.
padding_value (float, optional): value for padded elements. Default: ``-np.inf``.
Returns:
torch.Tensor: [seq_num, max_len] or [seq_num, pad_to]
"""
max_len = np.max(len_list) if pad_to is None else pad_to
min_len = np.min(len_list)
device = sequences[0].device
if max_len == min_len:
result = torch.cat(sequences, dim=0).view(-1, max_len)
else:
extra_len_list = np.subtract(max_len, len_list).tolist()
padding_nums = max_len * len(len_list) - np.sum(len_list)
padding_tensor = torch.tensor([padding_value], device=device, dtype=sequences[0].dtype).repeat(padding_nums)
padding_list = torch.split(padding_tensor, extra_len_list)
result = list(itertools.chain.from_iterable(zip(sequences, padding_list)))
result = torch.cat(result)
return result.view(-1, max_len)
def trunc(scores, method):
"""Round the scores by using the given method
Args:
scores (numpy.ndarray): scores
method (str): one of ['ceil', 'floor', 'around']
Raises:
NotImplementedError: method error
Returns:
numpy.ndarray: processed scores
"""
try:
cut_method = getattr(np, method)
except AttributeError:
raise NotImplementedError("module 'numpy' has no function named '{}'".format(method))
scores = cut_method(scores)
return scores
def cutoff(scores, threshold):
"""cut of the scores based on threshold
Args:
scores (numpy.ndarray): scores
threshold (float): between 0 and 1
Returns:
numpy.ndarray: processed scores
"""
return np.where(scores > threshold, 1, 0)
def _binary_clf_curve(trues, preds):
"""Calculate true and false positives per binary classification threshold
Args:
trues (numpy.ndarray): the ground-truth binary labels
preds (numpy.ndarray): the predicted scores
Returns:
fps (numpy.ndarray): A count of false positives, at index i being the number of negative
samples assigned a score >= thresholds[i]
tps (numpy.ndarray): An increasing count of true positives, at index i being the number
of positive samples assigned a score >= thresholds[i].
Note:
To improve efficiency, we referred to the source code of ``sklearn.metrics.roc_curve``
in scikit-learn and made some optimizations.
"""
trues = (trues == 1)
desc_idxs = np.argsort(preds)[::-1]
preds = preds[desc_idxs]
trues = trues[desc_idxs]
unique_val_idxs = np.where(np.diff(preds))[0]
threshold_idxs = np.r_[unique_val_idxs, trues.size - 1]
tps = np.cumsum(trues)[threshold_idxs]
fps = 1 + threshold_idxs - tps
return fps, tps
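# Guarded demo (a minimal sketch on toy data): pads two variable-length score
# tensors, rounds and binarizes scores, and computes the per-threshold
# FP/TP counts used by threshold-based metrics.
if __name__ == "__main__":
    seqs = [torch.tensor([0.9, 0.5]), torch.tensor([0.7])]
    print(pad_sequence(seqs, len_list=[2, 1]))       # [[0.9, 0.5], [0.7, -inf]]
    scores = np.array([0.2, 0.8, 0.6])
    print(trunc(scores, 'around'))                   # [0. 1. 1.]
    print(cutoff(scores, 0.5))                       # [0 1 1]
    trues = np.array([0, 1, 1, 0])
    preds = np.array([0.1, 0.8, 0.4, 0.35])
    fps, tps = _binary_clf_curve(trues, preds)
    print(fps, tps)                                  # [0 0 1 2] [1 2 2 2]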
# Copyright 2020 The HuggingFace Team. All rights reserved.
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliate
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from collections import OrderedDict
import torch
from torch import Tensor, nn
class NewGELUActivation(nn.Module):
"""
Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
"""
def forward(self, input: Tensor) -> Tensor:
return (
0.5
* input
* (
1.0
+ torch.tanh(
math.sqrt(2.0 / math.pi)
* (input + 0.044715 * torch.pow(input, 3.0))
)
)
)
class GELUActivation(nn.Module):
"""
Original Implementation of the GELU activation function in Google BERT repo when initially created. For
information: OpenAI GPT's GELU is slightly different (and gives slightly different results): 0.5 * x * (1 +
torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) This is now written in C in nn.functional
Also see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
"""
def __init__(self, use_gelu_python: bool = False):
super().__init__()
if use_gelu_python:
self.act = self._gelu_python
else:
self.act = nn.functional.gelu
def _gelu_python(self, input: Tensor) -> Tensor:
return input * 0.5 * (1.0 + torch.erf(input / math.sqrt(2.0)))
def forward(self, input: Tensor) -> Tensor:
return self.act(input)
class FastGELUActivation(nn.Module):
"""
Applies GELU approximation that is slower than QuickGELU but more accurate. See: https://github.com/hendrycks/GELUs
"""
def forward(self, input: Tensor) -> Tensor:
return (
0.5
* input
* (
1.0
+ torch.tanh(input * 0.7978845608 * (1.0 + 0.044715 * input * input))
)
)
class QuickGELUActivation(nn.Module):
"""
Applies GELU approximation that is fast but somewhat inaccurate. See: https://github.com/hendrycks/GELUs
"""
def forward(self, input: Tensor) -> Tensor:
return input * torch.sigmoid(1.702 * input)
class ClippedGELUActivation(nn.Module):
"""
Clip the range of possible GeLU outputs between [min, max]. This is especially useful for quantization purpose, as
it allows mapping negatives values in the GeLU spectrum. For more information on this trick, please refer to
https://arxiv.org/abs/2004.09602.
Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when
initially created.
For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 0.5 * x * (1 +
torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))). See https://arxiv.org/abs/1606.08415
"""
def __init__(self, min: float, max: float):
if min > max:
raise ValueError(f"min should be < max (got min: {min}, max: {max})")
super().__init__()
self.min = min
self.max = max
def forward(self, x: Tensor) -> Tensor:
return torch.clip(gelu(x), self.min, self.max)
class SiLUActivation(nn.Module):
"""
See Gaussian Error Linear Units (Hendrycks et al., https://arxiv.org/abs/1606.08415) where the SiLU (Sigmoid Linear
Unit) was originally introduced and coined, and see Sigmoid-Weighted Linear Units for Neural Network Function
Approximation in Reinforcement Learning (Elfwing et al., https://arxiv.org/abs/1702.03118) and Swish: a Self-Gated
Activation Function (Ramachandran et al., https://arxiv.org/abs/1710.05941v1) where the SiLU was experimented with
later.
"""
def forward(self, input: Tensor) -> Tensor:
return nn.functional.silu(input)
class LinearActivation(nn.Module):
"""
Applies the linear activation function, i.e. forwarding input directly to output.
"""
def forward(self, input: Tensor) -> Tensor:
return input
class ClassInstantier(OrderedDict):
def __getitem__(self, key):
content = super().__getitem__(key)
cls, kwargs = content if isinstance(content, tuple) else (content, {})
return cls(**kwargs)
ACT2CLS = {
"gelu": GELUActivation,
"gelu_10": (ClippedGELUActivation, {"min": -10, "max": 10}),
"gelu_fast": FastGELUActivation,
"gelu_new": NewGELUActivation,
"gelu_python": (GELUActivation, {"use_gelu_python": True}),
"linear": LinearActivation,
"quick_gelu": QuickGELUActivation,
"relu": nn.ReLU,
"relu6": nn.ReLU6,
"sigmoid": nn.Sigmoid,
"silu": SiLUActivation,
"swish": SiLUActivation,
"tanh": nn.Tanh,
}
ACT2FN = ClassInstantier(ACT2CLS)
def get_activation(activation_string):
if activation_string in ACT2FN:
return ACT2FN[activation_string]
else:
raise KeyError(
f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}"
)
# For backwards compatibility with: from activations import gelu_python
gelu_python = get_activation("gelu_python")
gelu_new = get_activation("gelu_new")
gelu = get_activation("gelu")
gelu_fast = get_activation("gelu_fast")
quick_gelu = get_activation("quick_gelu")
silu = get_activation("silu")
linear_act = get_activation("linear")
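# Guarded demo (a minimal sketch): model configs typically select activations
# by name, and ACT2FN instantiates the mapped class on lookup, passing kwargs
# for entries such as "gelu_10".
if __name__ == "__main__":
    act = get_activation("quick_gelu")
    x = torch.linspace(-2.0, 2.0, steps=5)
    print(act(x))                                                 # elementwise x * sigmoid(1.702 * x)
    print(isinstance(ACT2FN["gelu_10"], ClippedGELUActivation))   # True: kwargs come from ACT2CLS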
# Copyright 2023 Baichuan Inc. All Rights Reserved.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliate
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
logger = logging.get_logger(__name__)
class BaichuanConfig(PretrainedConfig):
model_type = "baichuan"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=125696,
hidden_size=4096,
intermediate_size=11008,
num_hidden_layers=32,
num_attention_heads=32,
hidden_act="silu",
max_position_embeddings=4096,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
tie_word_embeddings=False,
z_loss_weight=0,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.z_loss_weight = z_loss_weight
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
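# Usage sketch (hedged): a deliberately small configuration for local
# experiments; any field not overridden keeps the default shown in __init__
# (e.g. vocab_size=125696, hidden_act="silu").
if __name__ == "__main__":
    tiny_cfg = BaichuanConfig(hidden_size=512, num_hidden_layers=4, num_attention_heads=8)
    print(tiny_cfg.vocab_size, tiny_cfg.hidden_act)  # 125696 silu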
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliate
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import torch
import torch.nn.functional as F
from flash_attn.flash_attn_interface import (
flash_attn_qkvpacked_func,
flash_attn_varlen_qkvpacked_func,
)
def flash_self_attention(
qkv: torch.Tensor,
causal: bool = False,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[int] = None,
softmax_scale: Optional[float] = None,
attention_dropout: float = 0.0,
training: bool = False,
):
"""Implements the multihead softmax attention.
Modified from https://github.com/Dao-AILab/flash-attention/blob/v2.0.4/flash_attn/modules/mha.py#L35-L84
Arguments
---------
qkv: The tensor containing the query, key, and value.
If cu_seqlens is None and max_seqlen is None, then qkv has shape (B, S, 3, H, D).
If cu_seqlens is not None and max_seqlen is not None, then qkv has shape
(total, 3, H, D), where total is the sum of the sequence lengths in the batch.
causal: whether to apply a causal attention mask.
cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into qkv.
max_seqlen: int. Maximum sequence length in the batch.
Returns:
--------
out: (total, H, D) if cu_seqlens is not None and max_seqlen is not None,
else (B, S, H, D).
"""
assert qkv.dtype in [torch.float16, torch.bfloat16]
assert qkv.is_cuda
unpadded = cu_seqlens is not None
if unpadded:
assert cu_seqlens.dtype == torch.int32
assert max_seqlen is not None
assert isinstance(max_seqlen, int)
return flash_attn_varlen_qkvpacked_func(
qkv,
cu_seqlens,
max_seqlen,
attention_dropout if training else 0.0,
softmax_scale=softmax_scale,
causal=causal,
)
else:
return flash_attn_qkvpacked_func(
qkv,
attention_dropout if training else 0.0,
softmax_scale=softmax_scale,
causal=causal,
)
def compute_flash_attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
cu_input_lens: Optional[torch.Tensor] = None,
causal: bool = True,
training: bool = False,
attention_dropout: float = 0.0,
):
"""Modified from https://github.com/LAION-AI/Open-Assistant/blob/main/model/model_training/models/patching_utils.py"""
# q, k, v: [bs, seq_len, num_attention_heads, attn_head_size]
# attention_mask (float): [bs, seq_len]
batch_size, max_len = q.size(0), q.size(1)
qkv = torch.stack([q, k, v], dim=2) # [bs, seq_len, 3, num_attention_heads, attn_head_size]
if cu_input_lens is not None:
qkv.squeeze_(0)
cu_seqlens = F.pad(cu_input_lens.cumsum(dim=0, dtype=torch.int32), (1, 0))
max_seqlen = cu_input_lens.max().item()
out = flash_self_attention(
qkv,
causal=causal,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
training=training,
attention_dropout=attention_dropout,
)
return out
elif attention_mask is None:
return flash_self_attention(qkv, causal=causal, training=training, attention_dropout=attention_dropout)
else:
# Limitation: a non-contiguous attention mask will not be handled correctly;
# the model attends to everything between the first and last non-masked tokens, i.e. only left- and right-side padding is supported.
cur_mask = attention_mask >= 0
csums = cur_mask.cumsum(dim=1, dtype=torch.int32)
ends = csums.argmax(dim=1) + 1
starts = ends - csums.max(dim=1).values
seqlens = ends - starts
# qkv = torch.cat([qkv[i, starts[i] : ends[i]] for i in range(batch_size)], dim=0)
qkv = qkv.view(batch_size*max_len, *qkv.size()[2:])
cur_mask = cur_mask.flatten().nonzero().squeeze()
qkv = qkv[cur_mask]
cu_seqlens = F.pad(seqlens.cumsum(dim=0, dtype=torch.int32), (1, 0))
max_seqlen = seqlens.max().item()
out = flash_self_attention(
qkv,
causal=causal,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
training=training,
attention_dropout=attention_dropout
)
# out: [num_unmasked_tokens, num_attention_heads, attn_head_size]
seqs = [out[start:end] for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:])]
# stack and pad sequences together
padded_seqs = [
F.pad(
seqs[i],
(0, 0) * (seqs[i].dim() - 1) + (starts[i], max_len - ends[i]),
value=0.0,
)
for i in range(batch_size)
]
out = torch.stack(padded_seqs)
return out
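# Guarded smoke test (a minimal sketch): requires a CUDA device and the
# flash-attn package that the imports above already depend on. Shapes follow
# the comments in compute_flash_attention: q/k/v are
# [bs, seq_len, num_attention_heads, attn_head_size] in fp16.
if __name__ == "__main__":
    if torch.cuda.is_available():
        q = torch.randn(2, 8, 4, 16, device="cuda", dtype=torch.float16)
        k, v = torch.randn_like(q), torch.randn_like(q)
        out = compute_flash_attention(q, k, v, causal=True)  # dense path, no padding mask
        print(out.shape)  # torch.Size([2, 8, 4, 16])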
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliate
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import nn
import torch.nn.functional as F
import torch.distributed as dist
import numpy as np
import transformers
from transformers import AutoConfig, AutoModelForCausalLM
from logging import getLogger
from REC.utils.enum_type import InputType
from REC.model.basemodel import BaseModel, all_gather
from REC.model.HLLM.modeling_llama import LlamaForCausalLM
from REC.model.HLLM.modeling_mistral import MistralForCausalLM
from REC.model.HLLM.modeling_bert import BertModel
from REC.model.HLLM.baichuan.modeling_baichuan import BaichuanForCausalLM
class HLLM(BaseModel):
input_type = InputType.SEQ
def __init__(self, config, dataload):
super(HLLM, self).__init__()
self.logger = getLogger()
self.item_pretrain_dir = config['item_pretrain_dir']
self.user_pretrain_dir = config['user_pretrain_dir']
self.gradient_checkpointing = config['gradient_checkpointing']
self.use_ft_flash_attn = config['use_ft_flash_attn']
self.logger.info(f"create item llm")
self.item_llm = self.create_llm(self.item_pretrain_dir, config['item_llm_init'])
self.logger.info(f"create user llm")
self.user_llm = self.create_llm(self.user_pretrain_dir, config['user_llm_init'])
self.item_emb_token_n = config['item_emb_token_n']
if self.item_emb_token_n > 1:
raise NotImplementedError(f"Not support item_emb_token_n {self.item_emb_token_n} > 1")
if self.item_emb_token_n > 0:
self.item_emb_tokens = nn.Parameter(
torch.zeros(1, self.item_emb_token_n, self.item_llm.config.hidden_size)
)
self.item_emb_tokens.data.normal_(mean=0.0, std=0.02)
if config['item_emb_pretrain']:
ckpt = torch.load(config['item_emb_pretrain'], map_location='cpu')
self.logger.info(f"load item_emb_token from {config['item_emb_pretrain']} with {ckpt.size()}")
self.item_emb_tokens.data = nn.Parameter(ckpt)
else: # mean pooling
self.item_emb_tokens = None
self.loss = config['loss']
if self.loss == 'nce':
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
self.nce_thres = config['nce_thres'] if config['nce_thres'] else 0.99
self.num_negatives = config['num_negatives']
self.logger.info(f"nce thres setting to {self.nce_thres}")
else:
raise NotImplementedError(f"Only nce is supported")
if config['load_pretrain']:
state_dict = torch.load(config['load_pretrain'], map_location="cpu")
msg = self.load_state_dict(state_dict, strict=False)
self.logger.info(f"{msg.missing_keys = }")
self.logger.info(f"{msg.unexpected_keys = }")
def create_llm(self, pretrain_dir, init=True):
self.logger.info(f"******* create LLM {pretrain_dir} *******")
hf_config = AutoConfig.from_pretrained(pretrain_dir, trust_remote_code=True)
self.logger.info(f"hf_config: {hf_config}")
hf_config.gradient_checkpointing = self.gradient_checkpointing
hf_config.use_cache = False
hf_config.output_hidden_states = True
hf_config.return_dict = True
self.logger.info("xxxxx starting loading checkpoint")
if isinstance(hf_config, transformers.LlamaConfig):
hf_config.use_ft_flash_attn = self.use_ft_flash_attn
self.logger.info(f'Using flash attention {hf_config.use_ft_flash_attn} for llama')
self.logger.info(f'Init {init} for llama')
if init:
return LlamaForCausalLM.from_pretrained(pretrain_dir, config=hf_config)
else:
return LlamaForCausalLM(config=hf_config).cuda()
elif isinstance(hf_config, transformers.MistralConfig):
hf_config.use_ft_flash_attn = self.use_ft_flash_attn
self.logger.info(f'Using flash attention {hf_config.use_ft_flash_attn} for mistral')
self.logger.info(f'Init {init} for mistral')
if init:
return MistralForCausalLM.from_pretrained(pretrain_dir, config=hf_config)
else:
return MistralForCausalLM(config=hf_config).cuda()
elif isinstance(hf_config, transformers.BertConfig):
hf_config.use_ft_flash_attn = self.use_ft_flash_attn
self.logger.info(f'Using flash attention {hf_config.use_ft_flash_attn} for bert')
self.logger.info(f'Init {init} for bert')
if init:
return BertModel.from_pretrained(pretrain_dir, config=hf_config)
else:
return BertModel(config=hf_config).cuda()
elif getattr(hf_config, "model_type", None) == "baichuan":
hf_config.use_ft_flash_attn = self.use_ft_flash_attn
self.logger.info(f'Using flash attention {hf_config.use_ft_flash_attn} for baichuan')
self.logger.info(f'Init {init} for baichuan')
if init:
return BaichuanForCausalLM.from_pretrained(pretrain_dir, config=hf_config)
else:
return BaichuanForCausalLM(config=hf_config).cuda()
else:
return AutoModelForCausalLM.from_pretrained(
pretrain_dir, config=hf_config
)
def nce_loss(self, cur_embs, target_pos, target_neg, user_attention_mask):
with torch.no_grad():
self.logit_scale.clamp_(0, np.log(100))
logit_scale = self.logit_scale.exp()
D = target_neg.size(-1)
output_embs = cur_embs / cur_embs.norm(dim=-1, keepdim=True)
target_pos_embs = target_pos / target_pos.norm(dim=-1, keepdim=True)
pos_logits = F.cosine_similarity(output_embs, target_pos_embs, dim=-1).unsqueeze(-1)
target_neg = target_neg / target_neg.norm(dim=-1, keepdim=True)
neg_embedding_all = all_gather(target_neg, sync_grads=True).reshape(-1, D) # [num, dim]
neg_embedding_all = neg_embedding_all.transpose(-1, -2)
neg_logits = torch.matmul(output_embs, neg_embedding_all)
fix_logits = torch.matmul(target_pos_embs, neg_embedding_all)
neg_logits[fix_logits > self.nce_thres] = torch.finfo(neg_logits.dtype).min
logits = torch.cat([pos_logits, neg_logits], dim=-1)
logits = logits[user_attention_mask.bool()] * logit_scale
labels = torch.zeros(logits.size(0), device=logits.device, dtype=torch.int64)
return logits, labels
def forward_item_emb(
self,
input_ids,
position_ids,
cu_input_lens,
emb_token_n,
emb_tokens,
llm
):
inputs_embeds = llm.get_input_embeddings()(input_ids)
emb_pos = cu_input_lens.cumsum(dim=0, dtype=torch.int32)
if emb_token_n > 0:
inputs_embeds[emb_pos - 1] = emb_tokens
model_out = llm(inputs_embeds=inputs_embeds.unsqueeze(0), cu_input_lens=cu_input_lens, position_ids=position_ids.unsqueeze(0))
model_out = model_out.hidden_states[-1].squeeze(0)
if emb_token_n > 0:
emb = model_out[emb_pos - 1]
else:
max_len = cu_input_lens.max().item()
cu_seqlens = F.pad(cu_input_lens.cumsum(dim=0, dtype=torch.int32), (1, 0))
seqs = [model_out[start:end] for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:])]
padded_seqs = [
F.pad(
seqs[i],
(0, 0) * (seqs[i].dim() - 1) + (0, max_len - cu_input_lens[i]),
value=0.0,
)
for i in range(cu_input_lens.size(0))
]
out = torch.stack(padded_seqs)
emb = out.sum(dim=1) / cu_input_lens.unsqueeze(1)
return emb
def forward(self, interaction, mode='train'):
if mode == 'predict':
return self.predict(interaction[0], interaction[1], interaction[2])
if mode == 'compute_item':
return self.compute_item(interaction)
user_attention_mask = interaction['attention_mask']
N, S = user_attention_mask.shape
pos_input_ids, pos_cu_input_lens, pos_position_ids = interaction['pos_input_ids'], interaction['pos_cu_input_lens'], interaction['pos_position_ids']
neg_input_ids, neg_cu_input_lens, neg_position_ids = interaction['neg_input_ids'], interaction['neg_cu_input_lens'], interaction['neg_position_ids']
pos_embedding = self.forward_item_emb(pos_input_ids, pos_position_ids, pos_cu_input_lens, self.item_emb_token_n, self.item_emb_tokens, self.item_llm)
pos_embedding = pos_embedding.reshape(N, S+1, -1)
neg_embedding = self.forward_item_emb(neg_input_ids, neg_position_ids, neg_cu_input_lens, self.item_emb_token_n, self.item_emb_tokens, self.item_llm)
neg_embedding = neg_embedding.reshape(N, -1, self.item_llm.config.hidden_size)
target_pos_embs = pos_embedding[:, 1:]
target_neg_embs = neg_embedding
user_embedding = self.user_llm(inputs_embeds=pos_embedding[:, :-1], attention_mask=user_attention_mask).hidden_states[-1]
model_out = {}
logits, labels = self.nce_loss(user_embedding, target_pos_embs, target_neg_embs, user_attention_mask)
model_out['loss'] = F.cross_entropy(logits, labels)
model_out['nce_samples'] = (logits > torch.finfo(logits.dtype).min/100).sum(dim=1).float().mean() # samples after filtering same negatives
for k in [1, 5, 10, 50, 100]:
if k > logits.size(1):
break
indices = logits.topk(k, dim=1).indices
model_out[f"nce_top{k}_acc"] = labels.view(-1, 1).eq(indices).any(dim=1).float().mean()
return model_out
@torch.no_grad()
def predict(self, item_seq, time_seq, item_feature):
attention_mask = (item_seq > 0).int()
pos_embedding = item_feature[item_seq]
user_embedding = self.user_llm(inputs_embeds=pos_embedding, attention_mask=attention_mask).hidden_states[-1]
seq_output = user_embedding[:, -1]
seq_output = seq_output / seq_output.norm(dim=-1, keepdim=True)
item_feature = item_feature / item_feature.norm(dim=-1, keepdim=True)
return torch.matmul(seq_output, item_feature.t())
@torch.no_grad()
def compute_item_all(self):
return self.item_embedding.weight
@torch.no_grad()
def compute_item(self, interaction):
pos_input_ids, pos_cu_input_lens, pos_position_ids = interaction['pos_input_ids'], interaction['pos_cu_input_lens'], interaction['pos_position_ids']
pos_embedding = self.forward_item_emb(pos_input_ids, pos_position_ids, pos_cu_input_lens, self.item_emb_token_n, self.item_emb_tokens, self.item_llm)
N = pos_cu_input_lens.size(0)
pos_embedding = pos_embedding.view(N, -1)
return pos_embedding
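# Standalone toy sketch (hedged; independent of the HLLM class above): shows
# the structure of the NCE objective built in nce_loss: the positive
# similarity occupies column 0, negatives fill the remaining columns,
# negatives that are near-duplicates of the positive are masked to -inf, and
# the cross-entropy target is therefore always class 0. all_gather, the
# learned logit_scale and the attention-mask selection are omitted here.
if __name__ == "__main__":
    torch.manual_seed(0)
    B, D, num_neg, nce_thres = 4, 8, 16, 0.99
    user_emb = F.normalize(torch.randn(B, D), dim=-1)
    pos_emb = F.normalize(torch.randn(B, D), dim=-1)
    neg_emb = F.normalize(torch.randn(num_neg, D), dim=-1)
    pos_logits = (user_emb * pos_emb).sum(-1, keepdim=True)   # [B, 1]
    neg_logits = user_emb @ neg_emb.t()                       # [B, num_neg]
    dup_mask = (pos_emb @ neg_emb.t()) > nce_thres            # filter accidental positives
    neg_logits[dup_mask] = torch.finfo(neg_logits.dtype).min
    logits = torch.cat([pos_logits, neg_logits], dim=-1)
    labels = torch.zeros(B, dtype=torch.int64)
    print(F.cross_entropy(logits, labels))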