"tests/vscode:/vscode.git/clone" did not exist on "97c8d94542b451814fdbf722eeabd9fff22796a4"
Commit f75058c7 authored by Rayyyyy's avatar Rayyyyy
Browse files

First add.

parents
Pipeline #1411 canceled with stages
import os
import torch
import logging
import torch.distributed as dist
from tqdm import tqdm
from dataclasses import asdict
from typing import Optional, List, Dict
from torch.utils.data import DataLoader, Dataset
from transformers.trainer import Trainer
from transformers.training_args import TrainingArguments
from .metrics import RetrievalMetric
from ..utils.util import save_json
from transformers.trainer_utils import EvalLoopOutput
from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState
logger = logging.getLogger(__name__)
class RetrievalTrainer(Trainer):
"""Trainer with retrieval-based evaluation."""
def __init__(self, *args, corpus:Dataset, model_args, file_logger, **kwargs):
super().__init__(*args, **kwargs)
self.corpus = corpus
# handle save/load of index/encodings/results
self.model_args = model_args
self.file_logger = file_logger
def _save(self, output_dir: Optional[str] = None, state_dict=None):
# If we are executing this function, we are the process zero, so we don't check for that.
output_dir = output_dir if output_dir is not None else self.args.output_dir
os.makedirs(output_dir, exist_ok=True)
logger.info(f"Saving model checkpoint to {output_dir}")
self.model.save_pretrained(
output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
)
if self.tokenizer is not None:
self.tokenizer.save_pretrained(output_dir)
all_args = {
"model_args": asdict(self.model_args),
"training_args": asdict(self.args),
}
# Good practice: save your training arguments together with the trained model
save_json(all_args, os.path.join(output_dir, "args.json"))
@torch.no_grad()
def evaluate(self, eval_dataset: Optional[Dataset] = None, ignore_keys: Optional[List[str]] = None, metric_key_prefix: str = "eval") -> Dict[str, float]:
# memory metrics - must set up as early as possible
self._memory_tracker.start()
if eval_dataset is None and self.eval_dataset is None:
return
args = self.args
self.model.eval()
# # make it to fp16
# dtype = self.model_args.dtype
# if dtype == "fp16":
# dtype = torch.float16
# else:
# dtype = torch.float32
# self.model.to(dtype)
# NOTE: very important to reset inbatch_same_dataset
inbatch_same_dataset = self.data_collator.inbatch_same_dataset
self.data_collator.inbatch_same_dataset = False
result_path = RetrievalMetric._get_save_path(self.model_args.eval_data, args.output_dir, field="result", save_name=self.model_args.save_name)
if self.model_args.load_result:
query_ids, preds, scores = RetrievalMetric._load_result(result_path)
else:
if args.eval_method == "retrieval":
# index corpus
self.model.index(
self.corpus,
output_dir=args.output_dir,
embedding_name=self.model_args.embedding_name,
index_factory=self.model_args.faiss_index_factory,
load_encode=self.model_args.load_encode,
save_encode=self.model_args.save_encode,
load_index=self.model_args.load_index,
save_index=self.model_args.save_index,
batch_size=self.model_args.batch_size,
)
# every process uses the same query because the corpus is sharded
dataloader = DataLoader(
self.eval_dataset,
batch_size=self.model_args.batch_size,
pin_memory=True,
collate_fn=self.data_collator,
)
query_ids = []
preds = [] # num_samples, hits
scores = []
for step, inputs in enumerate(tqdm(dataloader, desc="Searching")):
query_id = inputs.pop("query_id")
# the indices are already gathered, merged, and sorted inside search function
score, indice = self.model.search(inputs["query"], hits=self.model_args.hits) # batch_size, hits
query_ids.extend(query_id.tolist())
preds.extend(indice.tolist())
scores.extend(score.tolist())
elif args.eval_method == "rerank":
dataloader = DataLoader(
self.eval_dataset,
batch_size=self.model_args.batch_size,
pin_memory=True,
collate_fn=self.data_collator,
)
dataloader = self.accelerator.prepare(dataloader)
query_ids = []
preds = [] # num_samples, hits
scores = []
for step, inputs in enumerate(tqdm(dataloader, desc="Ranking")):
inputs = self._prepare_inputs(inputs)
query_id = inputs.pop("query_id")
key_index = inputs.pop("key_index") # batch_size, key_num
score, indice = self.model.rerank(**inputs, hits=self.model_args.hits) # batch_size, hits
# NOTE: when the indices of the keys (w.r.t. the corpus) are provided, we should rerank these indices instead of returning the raw indices
# NOTE: when using gather, the index must be bigger than -1!
gather_index = indice.clone()
gather_index[indice == -1] = 0
new_indice = key_index.gather(index=gather_index, dim=-1)
# NOTE: mask the padded candidate
indice = new_indice.masked_fill(indice == -1, -1)
query_id = self.accelerator.gather_for_metrics(query_id)
# NOTE: important to pad here for later gathering, because different devices may have different key number
# FIXME: dim cannot be -1
indice = self.accelerator.pad_across_processes(indice, pad_index=-1, dim=1)
score = self.accelerator.pad_across_processes(score, pad_index=torch.finfo(score.dtype).min, dim=1)
pred = self.accelerator.gather_for_metrics(indice.contiguous())
score = self.accelerator.gather_for_metrics(score.contiguous())
query_ids.extend(query_id.tolist())
preds.extend(pred.tolist())
scores.extend(score.tolist())
# if step > 4:
# break
else:
raise NotImplementedError(f"Eval method {args.eval_method} not implemented!")
if args.process_index == 0 and self.model_args.save_result:
RetrievalMetric._save_result(query_ids, preds, result_path, scores=scores)
if args.process_index == 0:
metrics = [self.compute_metrics(query_ids, preds, scores=scores)]
else:
metrics = [None]
# NOTE: broadcast across devices
dist.broadcast_object_list(metrics, src=0)
metrics = metrics[0]
self.accelerator.wait_for_everyone()
# reset
self.data_collator.inbatch_same_dataset = inbatch_same_dataset
# self.model.to(torch.float32)
# Prefix all keys with metric_key_prefix + '_'
for key in list(metrics.keys()):
if not key.startswith(f"{metric_key_prefix}_") and key != "epoch":
metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
output = EvalLoopOutput(predictions=preds, metrics=metrics, label_ids=None, num_samples=len(preds))
self.log(output.metrics)
self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics)
self._memory_tracker.stop_and_update_metrics(output.metrics)
# log to file
if args.process_index == 0:
self.file_logger.log(
metrics=metrics,
Model_Args=asdict(self.model_args),
Training_Args=asdict(args),
Global_Steps=self.state.global_step
)
return output.metrics
class EarlyExitCallBack(TrainerCallback):
def __init__(self, early_exit_steps=None):
self.early_exit_steps = early_exit_steps
def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
if self.early_exit_steps is not None and state.global_step > self.early_exit_steps:
control.should_training_stop = True
from .util import FileLogger, Sequential_Sampler, DatasetProcessFn, DefaultDataCollator, makedirs, split_file_dir_name_ext, clear_dir, get_max_length_in_nested_lists, pad_nested_lists, mask_nested_lists, are_elements_of_same_length, normalize_text, load_json, save_json, load_pickle, save_pickle, add_eos, remove_eos
from typing import List, Optional, Tuple
import torch
import types
import warnings
import importlib
import transformers
import logging
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, LlamaPreTrainedModel
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
from flash_attn.bert_padding import unpad_input, pad_input
from einops import rearrange
from peft.tuners.lora import LoraLayer
logger = logging.getLogger(__name__)
# ADAPTED from https://github.com/allenai/open-instruct/blob/main/open_instruct/llama_flash_attn_monkey_patch.py
# AND https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py
# AND https://github.com/LAION-AI/Open-Assistant/blob/04fa9a24b2a58c8885b8aa6a2eb02b18de6b4961/model/model_training/models/patching_llama.py
# AND Sourabh https://github.com/huggingface/transformers/commit/ee81bf5aee0d65f005d157c013777e3d27d8d6bf
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel
attention_mask: [bsz, q_len]
"""
if output_attentions:
warnings.warn(
"Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
)
bsz, q_len, _ = hidden_states.size()
query_states = (
self.q_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
key_states = (
self.k_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
value_states = (
self.v_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
# [bsz, q_len, nh, hd]
# [bsz, nh, q_len, hd]
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
# Past Key value support
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
# Flash attention codes from
# https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py
# transform the data into the format required by flash attention
qkv = torch.stack(
[query_states, key_states, value_states], dim=2
) # [bsz, nh, 3, q_len, hd]
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
# We have disabled _prepare_decoder_attention_mask in LlamaModel
# the attention_mask should be the same as the key_padding_mask
key_padding_mask = attention_mask
if key_padding_mask is None:
qkv = rearrange(qkv, "b s ... -> (b s) ...")
max_s = q_len
cu_q_lens = torch.arange(
0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
)
output = flash_attn_varlen_qkvpacked_func(
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
)
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
else:
nheads = qkv.shape[-2]
x = rearrange(qkv, "b s three h d -> b s (three h d)")
x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
x_unpad = rearrange(
x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads
)
output_unpad = flash_attn_varlen_qkvpacked_func(
x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
)
output = rearrange(
pad_input(
rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz, q_len
),
"b s (h d) -> b s h d",
h=nheads,
)
return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, past_key_value
# Disable the transformation of the attention mask in LlamaModel as the flash attention
# requires the attention mask to be the same as the key_padding_mask
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
# [bsz, seq_len]
return attention_mask
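# Usage sketch: call enable_flash_attention() before a Llama model is instantiated to patch the classes,
# or pass an already instantiated Llama model to patch its attention layers in place.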
def enable_flash_attention(model=None):
if model is not None and not isinstance(model, LlamaPreTrainedModel):
logger.warning(f"flash attention not implemented for model {type(model)}!")
return
logger.warning("reloading llama model, enabling flash attention...")
cuda_major, cuda_minor = torch.cuda.get_device_capability()
if cuda_major < 8:
print(
"Flash attention is only supported on Ampere or Hopper GPU during training due to head dim > 64 backward."
"ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
)
if model is None:
# override class, instantiate later
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
_prepare_decoder_attention_mask
)
transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
else:
# override model, already instantiated
if hasattr(model, "lm_head"):
model = model.model
model._prepare_decoder_attention_mask = types.MethodType(_prepare_decoder_attention_mask, model)
for layer in model.layers:
layer.self_attn.forward = types.MethodType(forward, layer.self_attn)
def disable_flash_attention(model=None):
if model is not None and not isinstance(model, LlamaPreTrainedModel):
logger.warning(f"flash attention not implemented for model {type(model)}!")
return
logger.warning("reloading llama model, disabling flash attention...")
if model is None:
# override class, instantiate later
importlib.reload(transformers.models.llama.modeling_llama)
else:
# override model, already instantiated
forward = transformers.models.llama.modeling_llama.LlamaAttention.forward
_prepare_decoder_attention_mask = transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask
if hasattr(model, "lm_head"):
model = model.model
model._prepare_decoder_attention_mask = types.MethodType(_prepare_decoder_attention_mask, model)
for layer in model.layers:
layer.self_attn.forward = types.MethodType(forward, layer.self_attn)
# Adapted from https://github.com/tmm1/axolotl/blob/2eda9e02a9d15a7a3f92b41f257d9844d72fc220/src/axolotl/utils/models.py#L338
def upcast_layer_for_flash_attention(model, torch_dtype):
# LlamaRMSNorm layers are in fp32 after kbit_training, so we need to
# convert them back to fp16/bf16 for flash-attn compatibility.
for name, module in model.named_modules():
if isinstance(module, LoraLayer):
module.to(torch_dtype)
if "norm" in name:
module.to(torch_dtype)
if "lm_head" in name or "embed_tokens" in name:
if hasattr(module, "weight"):
module.to(torch_dtype)
return model
import os
import sys
import pytz
import json
import torch
import shutil
import pathlib
import time
import pickle
import logging
import string
import numpy as np
import pandas as pd
from contextlib import contextmanager
from dataclasses import dataclass
from transformers.tokenization_utils import PreTrainedTokenizer
from datetime import datetime
from collections import defaultdict, OrderedDict
from typing import Optional, Tuple, Union, List, Callable, Dict, Any, Mapping
logger = logging.getLogger(__name__)
@contextmanager
def do_nothing():
yield
def makedirs(path):
p = pathlib.Path(path)
p.parent.mkdir(parents=True, exist_ok=True)
return path
def clear_dir(directory):
if not os.path.exists(directory):
os.makedirs(directory, exist_ok=True)
for filename in os.listdir(directory):
file_path = os.path.join(directory, filename)
try:
if os.path.isfile(file_path) or os.path.islink(file_path):
os.unlink(file_path)
elif os.path.isdir(file_path):
shutil.rmtree(file_path)
except Exception as e:
print('Failed to delete %s. Reason: %s' % (file_path, e))
def split_file_dir_name_ext(path):
"""Return the directory, name, and extension of a given file."""
p = pathlib.Path(path)
assert p.is_file()
return p.parent, p.stem, p.suffix
def save_pickle(obj, path:str):
"""
Save pickle file.
"""
if not os.path.exists(path):
makedirs(path)
with open(path, "wb") as f:
return pickle.dump(obj, f)
def load_pickle(path):
with open(path, "rb") as f:
return pickle.load(f)
def save_json(obj, path:str):
if not os.path.exists(path):
makedirs(path)
with open(path, "w") as f:
return json.dump(obj, f, ensure_ascii=False)
def load_json(path, lines=False):
if lines:
output = []
with open(path, "r") as f:
for line in f:
output.append(json.loads(line))
return output
else:
with open(path, "r") as f:
return json.load(f)
@contextmanager
def filelock(path, process_index=0):
i = 0
while os.path.exists(path):
if i == 0 and process_index == 0:
logger.info("found lock, waiting for other programs...")
time.sleep(3)
i = 1
if process_index == 0:
save_json("this is a lock", path)
yield
if process_index == 0:
os.remove(path)
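# Illustrative example: normalize_text("Hello, World!") -> "hello world" (lowercased, punctuation stripped, whitespace collapsed)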
def normalize_text(text, ignore_case=True, ignore_punctuation=True, ignore_space=True, ignore_number=False):
if isinstance(text, str):
text = [text]
unpack = True
else:
unpack = False
if ignore_case:
text = np.char.lower(text)
if ignore_punctuation:
repl_table = string.punctuation.maketrans("", "", string.punctuation)
text = np.char.translate(text, table=repl_table)
if ignore_number:
repl_table = string.digits.maketrans("", "", string.digits)
text = np.char.translate(text, table=repl_table)
if ignore_space:
for i, words in enumerate(np.char.split(text)):
text[i] = " ".join(words)
if isinstance(text, np.ndarray):
text = text.tolist()
if unpack:
text = text[0]
return text
def min_max_normalize(array):
return (array - array.min(-1)[:,None])/(array.max(-1) - array.min(-1))[:, None]
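# Illustrative example: get_max_length_in_nested_lists([[1, 2, 3], [4]]) -> 3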
def get_max_length_in_nested_lists(lst):
if len(lst) and isinstance(lst[0], list):
lengths = []
for elem in lst:
length = get_max_length_in_nested_lists(elem)
lengths.append(length)
max_length = max(lengths)
return max_length
else:
return len(lst)
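# Illustrative example: pad_nested_lists([[1, 2], [3]], 3, 0) -> ([[1, 2, 0], [3, 0, 0]], [[1, 1, 0], [1, 0, 0]])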
def pad_nested_lists(lst, max_length, padding_value, padding_side="right"):
if isinstance(lst, list) and len(lst) and isinstance(lst[0], list):
masks = []
for i, elem in enumerate(lst):
lst[i], mask = pad_nested_lists(elem, max_length, padding_value, padding_side)
masks.append(mask)
return lst, masks
elif isinstance(lst, list):
if padding_side == "right":
mask = [1] * len(lst) + [0] * (max_length - len(lst))
lst = lst + [padding_value for _ in range(max_length - len(lst))]
return lst, mask
else:
mask = [0] * (max_length - len(lst)) + [1] * len(lst)
lst = [padding_value for _ in range(max_length - len(lst))] + lst
return lst, mask
else:
raise NotImplementedError(f"Unrecognized type {lst}")
def mask_nested_lists(lst, mask_target, mask_value=0):
if isinstance(lst[0], list):
for i, elem in enumerate(lst):
lst[i] = mask_nested_lists(elem, mask_target, mask_value)
return lst
else:
return [x if x != mask_target else mask_value for x in lst]
def are_elements_of_same_length(lst: List):
if not isinstance(lst[0], list):
return False
length = len(lst[0])
return all(len(x) == length if isinstance(x, list) else False for x in lst)
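# Illustrative example with eos_token_id=2: add_eos({"input_ids": [5, 6], "attention_mask": [1, 1]}, 2) -> {"input_ids": [5, 6, 2], "attention_mask": [1, 1, 1]}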
def add_eos(inputs: Mapping, eos_token_id: int):
for k, v in inputs.items():
assert isinstance(v, list), f"Make sure the return_tensors are set to list!"
if k == "input_ids":
v = v + [eos_token_id]
elif k == "position_ids":
v = v + [v[-1] + 1]
elif k in ["attention_mask", "token_type_ids"]:
v = v + v[-1:]
else:
raise NotImplementedError(f"Inputs key {k} not implemented!")
inputs[k] = v
return inputs
def remove_eos(inputs: Mapping, eos_token_id: int):
input_ids = inputs["input_ids"]
eos_idx = [i for i, x in enumerate(input_ids) if x == eos_token_id][0]
for k, v in inputs.items():
inputs[k].pop(eos_idx)
return inputs
def mix_parameters(models: List[torch.nn.Module], weights: Optional[List[float]]=None):
"""Mix parameters of different models according to given weights.
Returns:
the model with mixed parameters.
"""
new_state_dict = OrderedDict()
if weights is None:
weights = [1 / len(models) for _ in range(len(models))]
else:
assert len(weights) == len(models), f"Make sure the size of mix weights equals to the number of models!"
for name_param_pairs in zip(*[model.state_dict().items() for model in models]):
names = [name_param_pair[0] for name_param_pair in name_param_pairs]
params = [name_param_pair[1] for name_param_pair in name_param_pairs]
assert all(name == names[0] for name in names), f"Found incompatible key in {names}!"
name = names[0]
mixed_param = None
# there may be non-float parameters stored, which should not be mixed
if params[0].dtype not in [torch.float16, torch.bfloat16, torch.float32]:
assert all((param == params[0]).all() for param in params), f"Found incompatible value in non-float tensor {params}!"
new_state_dict[name] = params[0]
continue
for weight, param in zip(weights, params):
if mixed_param is None:
mixed_param = weight * param
else:
mixed_param += weight * param
new_state_dict[name] = mixed_param
model = models[0]
info = model.load_state_dict(new_state_dict)
print(info)
return model
class FileLogger:
def __init__(self, log_file) -> None:
self.log_file = log_file
def log(self, metrics, **kwargs):
with open(self.log_file, "a+") as f:
# get current time
tz = pytz.timezone('Asia/Shanghai')
time = f"{'Time': <10}: {json.dumps(datetime.now(tz).strftime('%Y-%m-%d, %H:%M:%S'), ensure_ascii=False)}\n"
command = f"{'Command': <10}: {json.dumps(' '.join(sys.argv), ensure_ascii=False)}\n"
metrics = f"{'Metrics': <10}: {json.dumps(metrics, ensure_ascii=False)}\n"
msg = time + command
print(msg + metrics)
for key, value in kwargs.items():
try:
msg += f"{key: <10}: {json.dumps(value, ensure_ascii=False)}\n"
except:
print(key)
print(value)
raise
msg += metrics
f.write(str(msg) + "\n")
class Sequential_Sampler:
"""
The sampler used in creating sequential dataloader.
"""
def __init__(self, dataset_length:int, num_replicas:int, rank:int) -> None:
"""
Args:
dataset_length: length of the dataset
num_replicas: number of splits
rank: the current process id
Attributes:
start: the starting index
end: the ending index
"""
super().__init__()
len_per_worker = dataset_length / num_replicas
# force to set rank==0 because when world_size==1 the local_rank is -1 by default
if num_replicas == 1:
rank = 0
self.start = round(len_per_worker * rank)
self.end = round(len_per_worker * (rank + 1))
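# e.g. dataset_length=10, num_replicas=3: rank 0 -> [0, 3), rank 1 -> [3, 7), rank 2 -> [7, 10)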
self.rank = rank
def __iter__(self):
start = self.start
end = self.end
return iter(range(start, end, 1))
def __len__(self):
return self.end - self.start
class DatasetProcessFn:
"""Wrapper for any user-defined process function for huggingface datasets.
1. Process batched examples by looping the process function over them;
2. Gather returned examples if any data augmentation happens with augment=True;
3. Pass indices of examples inside the process function with _index keywords if they exist.
The wrapped function should take in any needed columns and return a dict with 1 or more samples.
"""
def __init__(self, augment=False):
self.augment = augment
def __call__(self, _process_fn):
def process(*args):
sample_or_batch_sample = args[0]
if len(args) == 1:
pass
elif len(args) == 2:
indices = args[1]
# detach the slice so that _index will not be set in the original data
sample_or_batch_sample = sample_or_batch_sample.copy()
sample_or_batch_sample["_index"] = indices
else:
raise NotImplementedError(f"Found more than 2 arguments {args}!")
keys = list(sample_or_batch_sample.keys())
func_args = [sample_or_batch_sample[k] for k in keys]
# FIXME: if all values in one sample are of the same length, this would fail
if are_elements_of_same_length(func_args):
outputs = defaultdict(list)
for arg in zip(*func_args):
# get each element in a batch
kwargs = {keys[j]: arg[j] for j in range(len(arg))}
output = _process_fn(**kwargs)
if output is not None:
for k, v in output.items():
if self.augment:
outputs[k].extend(v)
else:
outputs[k].append(v)
else:
outputs = _process_fn(**sample_or_batch_sample)
if outputs is None:
raise ValueError(f"Found None returned from process_fn. Make sure you set 'batched=True' when trying to augment/distract samples in the datasets!")
return dict(outputs)
return process
@dataclass
class DefaultDataCollator:
"""
Data collator that can:
1. Dynamically pad all inputs received. The inputs must be dict of lists.
2. Add position_ids based on attention_mask if required.
"""
tokenizer: PreTrainedTokenizer
attention_padding_value: int = 0
label_padding_value: int = -100
add_position_ids: bool = False
def __call__(self, batch_elem: List) -> Dict[str, Any]:
first_elem = batch_elem[0]
return_batch = {}
for key, value in first_elem.items():
# HACK: any key containing attention_mask must be attention_mask
# important to assign different pad token for different types of inputs
if "attention_mask" in key:
pad_token_id = self.attention_padding_value
elif "label" in key:
pad_token_id = self.label_padding_value
else:
pad_token_id = self.tokenizer.pad_token_id
batch_value = [elem[key] for elem in batch_elem]
# pad all lists and nested lists
if isinstance(value, list):
max_length = get_max_length_in_nested_lists(batch_value)
batch_value, _ = pad_nested_lists(batch_value, max_length, pad_token_id, self.tokenizer.padding_side)
return_batch[key] = torch.tensor(batch_value)
if "attention_mask" in key and self.add_position_ids:
value = return_batch[key]
position_ids = value.cumsum(-1) - 1
position_ids = position_ids.masked_fill(value == 0, 0)
return_batch[key.replace("attention_mask", "position_ids")] = position_ids
return return_batch
# Reranker
Different from an embedding model, a reranker takes the question and document as input and directly outputs a relevance score instead of an embedding.
You can feed a query and a passage to the reranker to get a relevance score directly, and the score can be mapped to a float value in [0,1] with a sigmoid function.
## Model List
- [bge-reranker-base](http://113.200.138.88:18080/aimodels/bge-reranker-base)
- [bge-reranker-large](http://113.200.138.88:18080/aimodels/bge-reranker-large)
You can select the model according to your scenario and resources:
- For **multilingual** use cases, use [BAAI/bge-reranker-v2-m3](https://huggingface.co/BAAI/bge-reranker-v2-m3) and [BAAI/bge-reranker-v2-gemma](https://huggingface.co/BAAI/bge-reranker-v2-gemma)
- For **Chinese or English**, use [BAAI/bge-reranker-v2-m3](https://huggingface.co/BAAI/bge-reranker-v2-m3) and [BAAI/bge-reranker-v2-minicpm-layerwise](https://huggingface.co/BAAI/bge-reranker-v2-minicpm-layerwise).
- For **efficiency**, use [BAAI/bge-reranker-v2-m3](https://huggingface.co/BAAI/bge-reranker-v2-m3) and the low layers of [BAAI/bge-reranker-v2-minicpm-layerwise](https://huggingface.co/BAAI/bge-reranker-v2-minicpm-layerwise).
- For better performance, we recommend [BAAI/bge-reranker-v2-minicpm-layerwise](https://huggingface.co/BAAI/bge-reranker-v2-minicpm-layerwise) and [BAAI/bge-reranker-v2-gemma](https://huggingface.co/BAAI/bge-reranker-v2-gemma)
## Usage
### Using FlagEmbedding
Make sure the environment is properly set up; see [environment setup](../../README.md#环境配置).
#### For normal reranker (bge-reranker-base / bge-reranker-large / bge-reranker-v2-m3 )
Get relevance scores (higher scores indicate more relevance):
```python
from FlagEmbedding import FlagReranker
reranker = FlagReranker('BAAI/bge-reranker-v2-m3', use_fp16=True) # Setting use_fp16 to True speeds up computation with a slight performance degradation
score = reranker.compute_score(['query', 'passage'])
print(score) # -5.65234375
# You can map the scores into 0-1 by set "normalize=True", which will apply sigmoid function to the score
score = reranker.compute_score(['query', 'passage'], normalize=True)
print(score) # 0.003497010252573502
scores = reranker.compute_score([['what is panda?', 'hi'], ['what is panda?', 'The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.']])
print(scores) # [-8.1875, 5.26171875]
# You can map the scores into 0-1 by set "normalize=True", which will apply sigmoid function to the score
scores = reranker.compute_score([['what is panda?', 'hi'], ['what is panda?', 'The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.']], normalize=True)
print(scores) # [0.00027803096387751553, 0.9948403768236574]
```
#### For LLM-based reranker
```python
from FlagEmbedding import FlagLLMReranker
reranker = FlagLLMReranker('BAAI/bge-reranker-v2-gemma', use_fp16=True) # Setting use_fp16 to True speeds up computation with a slight performance degradation
# reranker = FlagLLMReranker('BAAI/bge-reranker-v2-gemma', use_bf16=True) # You can also set use_bf16=True to speed up computation with a slight performance degradation
score = reranker.compute_score(['query', 'passage'])
print(score)
scores = reranker.compute_score([['what is panda?', 'hi'], ['what is panda?', 'The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.']])
print(scores)
```
#### For LLM-based layerwise reranker
```python
from FlagEmbedding import LayerWiseFlagLLMReranker
reranker = LayerWiseFlagLLMReranker('BAAI/bge-reranker-v2-minicpm-layerwise', use_fp16=True) # Setting use_fp16 to True speeds up computation with a slight performance degradation
# reranker = LayerWiseFlagLLMReranker('BAAI/bge-reranker-v2-minicpm-layerwise', use_bf16=True) # You can also set use_bf16=True to speed up computation with a slight performance degradation
score = reranker.compute_score(['query', 'passage'], cutoff_layers=[28]) # Adjusting 'cutoff_layers' to pick which layers are used for computing the score.
print(score)
scores = reranker.compute_score([['what is panda?', 'hi'], ['what is panda?', 'The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.']], cutoff_layers=[28])
print(scores)
```
### Using Huggingface transformers
#### For normal reranker (bge-reranker-base / bge-reranker-large / bge-reranker-v2-m3 )
Get relevance scores (higher scores indicate more relevance):
```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-reranker-v2-m3')
model = AutoModelForSequenceClassification.from_pretrained('BAAI/bge-reranker-v2-m3')
model.eval()
pairs = [['what is panda?', 'hi'], ['what is panda?', 'The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.']]
with torch.no_grad():
inputs = tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=512)
scores = model(**inputs, return_dict=True).logits.view(-1, ).float()
print(scores)
```
#### For LLM-based reranker
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
def get_inputs(pairs, tokenizer, prompt=None, max_length=1024):
if prompt is None:
prompt = "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'."
sep = "\n"
prompt_inputs = tokenizer(prompt,
return_tensors=None,
add_special_tokens=False)['input_ids']
sep_inputs = tokenizer(sep,
return_tensors=None,
add_special_tokens=False)['input_ids']
inputs = []
for query, passage in pairs:
query_inputs = tokenizer(f'A: {query}',
return_tensors=None,
add_special_tokens=False,
max_length=max_length * 3 // 4,
truncation=True)
passage_inputs = tokenizer(f'B: {passage}',
return_tensors=None,
add_special_tokens=False,
max_length=max_length,
truncation=True)
item = tokenizer.prepare_for_model(
[tokenizer.bos_token_id] + query_inputs['input_ids'],
sep_inputs + passage_inputs['input_ids'],
truncation='only_second',
max_length=max_length,
padding=False,
return_attention_mask=False,
return_token_type_ids=False,
add_special_tokens=False
)
item['input_ids'] = item['input_ids'] + sep_inputs + prompt_inputs
item['attention_mask'] = [1] * len(item['input_ids'])
inputs.append(item)
return tokenizer.pad(
inputs,
padding=True,
max_length=max_length + len(sep_inputs) + len(prompt_inputs),
pad_to_multiple_of=8,
return_tensors='pt',
)
tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-reranker-v2-gemma')
model = AutoModelForCausalLM.from_pretrained('BAAI/bge-reranker-v2-gemma')
yes_loc = tokenizer('Yes', add_special_tokens=False)['input_ids'][0]
model.eval()
pairs = [['what is panda?', 'hi'], ['what is panda?', 'The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.']]
with torch.no_grad():
inputs = get_inputs(pairs, tokenizer)
scores = model(**inputs, return_dict=True).logits[:, -1, yes_loc].view(-1, ).float()
print(scores)
```
#### For LLM-based layerwise reranker
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
def get_inputs(pairs, tokenizer, prompt=None, max_length=1024):
if prompt is None:
prompt = "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'."
sep = "\n"
prompt_inputs = tokenizer(prompt,
return_tensors=None,
add_special_tokens=False)['input_ids']
sep_inputs = tokenizer(sep,
return_tensors=None,
add_special_tokens=False)['input_ids']
inputs = []
for query, passage in pairs:
query_inputs = tokenizer(f'A: {query}',
return_tensors=None,
add_special_tokens=False,
max_length=max_length * 3 // 4,
truncation=True)
passage_inputs = tokenizer(f'B: {passage}',
return_tensors=None,
add_special_tokens=False,
max_length=max_length,
truncation=True)
item = tokenizer.prepare_for_model(
[tokenizer.bos_token_id] + query_inputs['input_ids'],
sep_inputs + passage_inputs['input_ids'],
truncation='only_second',
max_length=max_length,
padding=False,
return_attention_mask=False,
return_token_type_ids=False,
add_special_tokens=False
)
item['input_ids'] = item['input_ids'] + sep_inputs + prompt_inputs
item['attention_mask'] = [1] * len(item['input_ids'])
inputs.append(item)
return tokenizer.pad(
inputs,
padding=True,
max_length=max_length + len(sep_inputs) + len(prompt_inputs),
pad_to_multiple_of=8,
return_tensors='pt',
)
tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-reranker-v2-minicpm-layerwise', trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained('BAAI/bge-reranker-v2-minicpm-layerwise', trust_remote_code=True, torch_dtype=torch.bfloat16)
model = model.to('cuda')
model.eval()
pairs = [['what is panda?', 'hi'], ['what is panda?', 'The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.']]
with torch.no_grad():
inputs = get_inputs(pairs, tokenizer).to(model.device)
all_scores = model(**inputs, return_dict=True, cutoff_layers=[28])
all_scores = [scores[:, -1].view(-1, ).float() for scores in all_scores[0]]
print(all_scores)
```
## Fine-tune
### Data Format
The training data is a JSON file in which each line is a dict like this:
```
{"query": str, "pos": List[str], "neg":List[str], "prompt": str}
```
`query` is the query, `pos` is a list of positive texts, `neg` is a list of negative texts, and `prompt` describes the relationship between the query and the texts.
If you have no negative texts for a query, you can randomly sample some from the entire corpus as negatives, e.g. [toy_finetune_data.jsonl](../../examples/finetune/toy_finetune_data.jsonl)
### Train
You can fine-tune the reranker by following the steps below:
**For normal reranker** (bge-reranker-base / bge-reranker-large / bge-reranker-v2-m3 )
Refer to: ../../examples/reranker
**For llm-based reranker** (bge-reranker-v2-gemma)
```shell
torchrun --nproc_per_node {number of gpus} \
-m FlagEmbedding.llm_reranker.finetune_for_instruction.run \
--output_dir {path to save model} \
--model_name_or_path google/gemma-2b \
--train_data ./toy_finetune_data.jsonl \
--learning_rate 2e-4 \
--num_train_epochs 1 \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 16 \
--dataloader_drop_last True \
--query_max_len 512 \
--passage_max_len 512 \
--train_group_size 16 \
--logging_steps 1 \
--save_steps 2000 \
--save_total_limit 50 \
--ddp_find_unused_parameters False \
--gradient_checkpointing \
--deepspeed stage1.json \
--warmup_ratio 0.1 \
--bf16 \
--use_lora True \
--lora_rank 32 \
--lora_alpha 64 \
--use_flash_attn True \
--target_modules q_proj k_proj v_proj o_proj
```
**For llm-based layerwise reranker** (bge-reranker-v2-minicpm-layerwise)
```shell
torchrun --nproc_per_node {number of gpus} \
-m FlagEmbedding.llm_reranker.finetune_for_layerwise.run \
--output_dir {path to save model} \
--model_name_or_path openbmb/MiniCPM-2B-dpo-bf16 \
--train_data ./toy_finetune_data.jsonl \
--learning_rate 2e-4 \
--num_train_epochs 1 \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 16 \
--dataloader_drop_last True \
--query_max_len 512 \
--passage_max_len 512 \
--train_group_size 16 \
--logging_steps 1 \
--save_steps 2000 \
--save_total_limit 50 \
--ddp_find_unused_parameters False \
--gradient_checkpointing \
--deepspeed stage1.json \
--warmup_ratio 0.1 \
--bf16 \
--use_lora True \
--lora_rank 32 \
--lora_alpha 64 \
--use_flash_attn True \
--target_modules q_proj k_proj v_proj o_proj \
--start_layer 8 \
--head_multi True \
--head_type simple \
--lora_extra_parameters linear_head \
--finetune_type from_raw_model # should be one of ['from_raw_model', 'from_finetuned_model']
```
The rerankers are initialized from [google/gemma-2b](https://huggingface.co/google/gemma-2b) (for the llm-based reranker) and [openbmb/MiniCPM-2B-dpo-bf16](https://huggingface.co/openbmb/MiniCPM-2B-dpo-bf16) (for the llm-based layerwise reranker), and trained on a mixture of multilingual datasets:
- [bge-m3-data](https://huggingface.co/datasets/Shitao/bge-m3-data)
- [quora train data](https://huggingface.co/datasets/quora)
- [fever train data](https://fever.ai/dataset/fever.html)
### Merge Model
After fine-tuning, you need to merge the model.
**For llm-based reranker**
```python
from FlagEmbedding.llm_reranker.merge import merge_llm
merge_llm('google/gemma-2b', 'lora_llm_output_path', 'merged_model_output_paths')
```
**For llm-based layerwise reranker**
If you fine-tuned from the raw model (openbmb/MiniCPM-2B-dpo-bf16):
```python
from FlagEmbedding.llm_reranker.merge import merge_layerwise_raw_llm
merge_layerwise_raw_llm('openbmb/MiniCPM-2B-dpo-bf16', 'lora_llm_output_path', 'merged_model_output_paths')
```
If you fine-tuned from the already fine-tuned model (BAAI/bge-reranker-v2-minicpm-layerwise):
```python
from FlagEmbedding.llm_reranker.merge import merge_layerwise_finetuned_llm
merge_layerwise_finetuned_llm('BAAI/bge-reranker-v2-minicpm-layerwise', 'lora_llm_output_path', 'merged_model_output_paths')
```
# Reranker
- [Model List](#model-list)
- [Usage](#usage)
- [Fine-tuning](#fine-tune)
- [Evaluation](#evaluation)
- [Citation](#citation)
Different from embedding models, a reranker uses the question and document as input and directly outputs a similarity score instead of an embedding.
You can get a relevance score by feeding a query and a passage to the reranker.
The score can be mapped to a float value in [0,1] by a sigmoid function.
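For illustration only (the `normalize=True` option in the examples below applies the same mapping internally), a raw score can be converted like this:
```python
import math

raw_score = -5.65234375  # an example raw relevance score from the reranker
normalized = 1 / (1 + math.exp(-raw_score))  # sigmoid squashes the score into [0, 1]
print(normalized)  # ~0.0035
```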
## Model List
| Model | Base model | Language | layerwise | feature |
|:--------------------------------------------------------------------------|:--------:|:-----------------------------------------------------------------------------------------------------------------------------------:|:----------------------------------------------------------------------------------------------:|:----------------------------------------------------------------------------------------------:|
| [BAAI/bge-reranker-base](https://huggingface.co/BAAI/bge-reranker-base) | [xlm-roberta-base](https://huggingface.co/xlm-roberta-base) | Chinese and English | - | Lightweight reranker model, easy to deploy, with fast inference. |
| [BAAI/bge-reranker-large](https://huggingface.co/BAAI/bge-reranker-large) | [xlm-roberta-large](https://huggingface.co/FacebookAI/xlm-roberta-large) | Chinese and English | - | Lightweight reranker model, easy to deploy, with fast inference. |
| [BAAI/bge-reranker-v2-m3](https://huggingface.co/BAAI/bge-reranker-v2-m3) | [bge-m3](https://huggingface.co/BAAI/bge-m3) | Multilingual | - | Lightweight reranker model, possesses strong multilingual capabilities, easy to deploy, with fast inference. |
| [BAAI/bge-reranker-v2-gemma](https://huggingface.co/BAAI/bge-reranker-v2-gemma) | [gemma-2b](https://huggingface.co/google/gemma-2b) | Multilingual | - | Suitable for multilingual contexts, performs well in both English proficiency and multilingual capabilities. |
| [BAAI/bge-reranker-v2-minicpm-layerwise](https://huggingface.co/BAAI/bge-reranker-v2-minicpm-layerwise) | [MiniCPM-2B-dpo-bf16](https://huggingface.co/openbmb/MiniCPM-2B-dpo-bf16) | Multilingual | 8-40 | Suitable for multilingual contexts, performs well in both English and Chinese proficiency, allows freedom to select layers for output, facilitating accelerated inference. |
You can select the model according to your scenario and resources.
- For **multilingual** use cases, utilize [BAAI/bge-reranker-v2-m3](https://huggingface.co/BAAI/bge-reranker-v2-m3) and [BAAI/bge-reranker-v2-gemma](https://huggingface.co/BAAI/bge-reranker-v2-gemma)
- For **Chinese or English**, utilize [BAAI/bge-reranker-v2-m3](https://huggingface.co/BAAI/bge-reranker-v2-m3) and [BAAI/bge-reranker-v2-minicpm-layerwise](https://huggingface.co/BAAI/bge-reranker-v2-minicpm-layerwise).
- For **efficiency**, utilize [BAAI/bge-reranker-v2-m3](https://huggingface.co/BAAI/bge-reranker-v2-m3) and the low layers of [BAAI/bge-reranker-v2-minicpm-layerwise](https://huggingface.co/BAAI/bge-reranker-v2-minicpm-layerwise).
- For better performance, we recommend [BAAI/bge-reranker-v2-minicpm-layerwise](https://huggingface.co/BAAI/bge-reranker-v2-minicpm-layerwise) and [BAAI/bge-reranker-v2-gemma](https://huggingface.co/BAAI/bge-reranker-v2-gemma)
## Usage
### Using FlagEmbedding
```
pip install -U FlagEmbedding
```
#### For normal reranker (bge-reranker-base / bge-reranker-large / bge-reranker-v2-m3 )
Get relevance scores (higher scores indicate more relevance):
```python
from FlagEmbedding import FlagReranker
reranker = FlagReranker('BAAI/bge-reranker-v2-m3', use_fp16=True) # Setting use_fp16 to True speeds up computation with a slight performance degradation
score = reranker.compute_score(['query', 'passage'])
print(score) # -5.65234375
# You can map the scores into 0-1 by set "normalize=True", which will apply sigmoid function to the score
score = reranker.compute_score(['query', 'passage'], normalize=True)
print(score) # 0.003497010252573502
scores = reranker.compute_score([['what is panda?', 'hi'], ['what is panda?', 'The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.']])
print(scores) # [-8.1875, 5.26171875]
# You can map the scores into 0-1 by set "normalize=True", which will apply sigmoid function to the score
scores = reranker.compute_score([['what is panda?', 'hi'], ['what is panda?', 'The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.']], normalize=True)
print(scores) # [0.00027803096387751553, 0.9948403768236574]
```
#### For LLM-based reranker
```python
from FlagEmbedding import FlagLLMReranker
reranker = FlagLLMReranker('BAAI/bge-reranker-v2-gemma', use_fp16=True) # Setting use_fp16 to True speeds up computation with a slight performance degradation
# reranker = FlagLLMReranker('BAAI/bge-reranker-v2-gemma', use_bf16=True) # You can also set use_bf16=True to speed up computation with a slight performance degradation
score = reranker.compute_score(['query', 'passage'])
print(score)
scores = reranker.compute_score([['what is panda?', 'hi'], ['what is panda?', 'The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.']])
print(scores)
```
#### For LLM-based layerwise reranker
```python
from FlagEmbedding import LayerWiseFlagLLMReranker
reranker = LayerWiseFlagLLMReranker('BAAI/bge-reranker-v2-minicpm-layerwise', use_fp16=True) # Setting use_fp16 to True speeds up computation with a slight performance degradation
# reranker = LayerWiseFlagLLMReranker('BAAI/bge-reranker-v2-minicpm-layerwise', use_bf16=True) # You can also set use_bf16=True to speed up computation with a slight performance degradation
score = reranker.compute_score(['query', 'passage'], cutoff_layers=[28]) # Adjusting 'cutoff_layers' to pick which layers are used for computing the score.
print(score)
scores = reranker.compute_score([['what is panda?', 'hi'], ['what is panda?', 'The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.']], cutoff_layers=[28])
print(scores)
```
### Using Huggingface transformers
#### For normal reranker (bge-reranker-base / bge-reranker-large / bge-reranker-v2-m3 )
Get relevance scores (higher scores indicate more relevance):
```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-reranker-v2-m3')
model = AutoModelForSequenceClassification.from_pretrained('BAAI/bge-reranker-v2-m3')
model.eval()
pairs = [['what is panda?', 'hi'], ['what is panda?', 'The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.']]
with torch.no_grad():
inputs = tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=512)
scores = model(**inputs, return_dict=True).logits.view(-1, ).float()
print(scores)
```
#### For LLM-based reranker
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
def get_inputs(pairs, tokenizer, prompt=None, max_length=1024):
if prompt is None:
prompt = "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'."
sep = "\n"
prompt_inputs = tokenizer(prompt,
return_tensors=None,
add_special_tokens=False)['input_ids']
sep_inputs = tokenizer(sep,
return_tensors=None,
add_special_tokens=False)['input_ids']
inputs = []
for query, passage in pairs:
query_inputs = tokenizer(f'A: {query}',
return_tensors=None,
add_special_tokens=False,
max_length=max_length * 3 // 4,
truncation=True)
passage_inputs = tokenizer(f'B: {passage}',
return_tensors=None,
add_special_tokens=False,
max_length=max_length,
truncation=True)
item = tokenizer.prepare_for_model(
[tokenizer.bos_token_id] + query_inputs['input_ids'],
sep_inputs + passage_inputs['input_ids'],
truncation='only_second',
max_length=max_length,
padding=False,
return_attention_mask=False,
return_token_type_ids=False,
add_special_tokens=False
)
item['input_ids'] = item['input_ids'] + sep_inputs + prompt_inputs
item['attention_mask'] = [1] * len(item['input_ids'])
inputs.append(item)
return tokenizer.pad(
inputs,
padding=True,
max_length=max_length + len(sep_inputs) + len(prompt_inputs),
pad_to_multiple_of=8,
return_tensors='pt',
)
tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-reranker-v2-gemma')
model = AutoModelForCausalLM.from_pretrained('BAAI/bge-reranker-v2-gemma')
yes_loc = tokenizer('Yes', add_special_tokens=False)['input_ids'][0]
model.eval()
pairs = [['what is panda?', 'hi'], ['what is panda?', 'The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.']]
with torch.no_grad():
inputs = get_inputs(pairs, tokenizer)
scores = model(**inputs, return_dict=True).logits[:, -1, yes_loc].view(-1, ).float()
print(scores)
```
#### For LLM-based layerwise reranker
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
def get_inputs(pairs, tokenizer, prompt=None, max_length=1024):
if prompt is None:
prompt = "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'."
sep = "\n"
prompt_inputs = tokenizer(prompt,
return_tensors=None,
add_special_tokens=False)['input_ids']
sep_inputs = tokenizer(sep,
return_tensors=None,
add_special_tokens=False)['input_ids']
inputs = []
for query, passage in pairs:
query_inputs = tokenizer(f'A: {query}',
return_tensors=None,
add_special_tokens=False,
max_length=max_length * 3 // 4,
truncation=True)
passage_inputs = tokenizer(f'B: {passage}',
return_tensors=None,
add_special_tokens=False,
max_length=max_length,
truncation=True)
item = tokenizer.prepare_for_model(
[tokenizer.bos_token_id] + query_inputs['input_ids'],
sep_inputs + passage_inputs['input_ids'],
truncation='only_second',
max_length=max_length,
padding=False,
return_attention_mask=False,
return_token_type_ids=False,
add_special_tokens=False
)
item['input_ids'] = item['input_ids'] + sep_inputs + prompt_inputs
item['attention_mask'] = [1] * len(item['input_ids'])
inputs.append(item)
return tokenizer.pad(
inputs,
padding=True,
max_length=max_length + len(sep_inputs) + len(prompt_inputs),
pad_to_multiple_of=8,
return_tensors='pt',
)
tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-reranker-v2-minicpm-layerwise', trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained('BAAI/bge-reranker-v2-minicpm-layerwise', trust_remote_code=True, torch_dtype=torch.bfloat16)
model = model.to('cuda')
model.eval()
pairs = [['what is panda?', 'hi'], ['what is panda?', 'The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.']]
with torch.no_grad():
inputs = get_inputs(pairs, tokenizer).to(model.device)
all_scores = model(**inputs, return_dict=True, cutoff_layers=[28])
all_scores = [scores[:, -1].view(-1, ).float() for scores in all_scores[0]]
print(all_scores)
```
## Fine-tune
### Data Format
Train data should be a JSON file, where each line is a dict like this:
```
{"query": str, "pos": List[str], "neg":List[str], "prompt": str}
```
`query` is the query, `pos` is a list of positive texts, `neg` is a list of negative texts, and `prompt` indicates the relationship between the query and the texts. If you have no negative texts for a query, you can randomly sample some from the entire corpus as negatives.
See [toy_finetune_data.jsonl](https://github.com/FlagOpen/FlagEmbedding/tree/master/FlagEmbedding/llm_reranker/toy_finetune_data.jsonl) for a toy data file.
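For instance, a single (purely illustrative) training line could look like this:
```
{"query": "what is panda?", "pos": ["The giant panda (Ailuropoda melanoleuca) is a bear species endemic to China."], "neg": ["hi"], "prompt": "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'."}
```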
### Train
You can fine-tune the reranker with the following code:
**For normal reranker** (bge-reranker-base / bge-reranker-large / bge-reranker-v2-m3 )
Refer to: https://github.com/FlagOpen/FlagEmbedding/tree/master/examples/reranker
**For llm-based reranker** (bge-reranker-v2-gemma)
```shell
torchrun --nproc_per_node {number of gpus} \
-m FlagEmbedding.llm_reranker.finetune_for_instruction.run \
--output_dir {path to save model} \
--model_name_or_path google/gemma-2b \
--train_data ./toy_finetune_data.jsonl \
--learning_rate 2e-4 \
--num_train_epochs 1 \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 16 \
--dataloader_drop_last True \
--query_max_len 512 \
--passage_max_len 512 \
--train_group_size 16 \
--logging_steps 1 \
--save_steps 2000 \
--save_total_limit 50 \
--ddp_find_unused_parameters False \
--gradient_checkpointing \
--deepspeed stage1.json \
--warmup_ratio 0.1 \
--bf16 \
--use_lora True \
--lora_rank 32 \
--lora_alpha 64 \
--use_flash_attn True \
--target_modules q_proj k_proj v_proj o_proj
```
**For llm-based layerwise reranker** (bge-reranker-v2-minicpm-layerwise)
```shell
torchrun --nproc_per_node {number of gpus} \
-m FlagEmbedding.llm_reranker.finetune_for_layerwise.run \
--output_dir {path to save model} \
--model_name_or_path openbmb/MiniCPM-2B-dpo-bf16 \
--train_data ./toy_finetune_data.jsonl \
--learning_rate 2e-4 \
--num_train_epochs 1 \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 16 \
--dataloader_drop_last True \
--query_max_len 512 \
--passage_max_len 512 \
--train_group_size 16 \
--logging_steps 1 \
--save_steps 2000 \
--save_total_limit 50 \
--ddp_find_unused_parameters False \
--gradient_checkpointing \
--deepspeed stage1.json \
--warmup_ratio 0.1 \
--bf16 \
--use_lora True \
--lora_rank 32 \
--lora_alpha 64 \
--use_flash_attn True \
--target_modules q_proj k_proj v_proj o_proj \
--start_layer 8 \
--head_multi True \
--head_type simple \
--lora_extra_parameters linear_head \
--finetune_type from_raw_model # should be one of ['from_raw_model', 'from_finetuned_model']
```
Our rerankers are initialized from [google/gemma-2b](https://huggingface.co/google/gemma-2b) (for llm-based reranker) and [openbmb/MiniCPM-2B-dpo-bf16](https://huggingface.co/openbmb/MiniCPM-2B-dpo-bf16) (for llm-based layerwise reranker), and we train them on a mixture of multilingual datasets:
- [bge-m3-data](https://huggingface.co/datasets/Shitao/bge-m3-data)
- [quora train data](https://huggingface.co/datasets/quora)
- [fever train data](https://fever.ai/dataset/fever.html)
### Merge Model
After fine-tuning, you need to merge the model.
**For llm-based reranker**
```python
from FlagEmbedding.llm_reranker.merge import merge_llm
merge_llm('google/gemma-2b', 'lora_llm_output_path', 'merged_model_output_paths')
```
**For llm-based layerwise reranker**
If you fine-tuned from the raw model (openbmb/MiniCPM-2B-dpo-bf16):
```python
from FlagEmbedding.llm_reranker.merge import merge_layerwise_raw_llm
merge_layerwise_raw_llm('openbmb/MiniCPM-2B-dpo-bf16', 'lora_llm_output_path', 'merged_model_output_paths')
```
If you fine-tuned from the already fine-tuned model (BAAI/bge-reranker-v2-minicpm-layerwise):
```python
from FlagEmbedding.llm_reranker.merge import merge_layerwise_finetuned_llm
merge_layerwise_finetuned_llm('BAAI/bge-reranker-v2-minicpm-layerwise', 'lora_llm_output_path', 'merged_model_output_paths')
```
## Evaluation
- llama-index.
![image-20240317193909373](./evaluation/llama-index.png)
- BEIR.
Rerank the top 100 results from bge-en-v1.5 large.
![image-20240319140555921](./evaluation/BEIR-bge-en-v1.5.png)
Rerank the top 100 results from e5 mistral 7b instruct.
![image-20240317172949713](./evaluation/BEIR-e5-mistral.png)
- CMTEB-retrieval.
It reranks the top 100 results from bge-zh-v1.5 large.
![image-20240317173026235](./evaluation/CMTEB-retrieval-bge-zh-v1.5.png)
- miracl (multi-language).
It reranks the top 100 results from bge-m3.
![image-20240317173117639](./evaluation/miracl-bge-m3.png)
## Citation
If you find this repository useful, please consider giving a star :star: and a citation.
```
@misc{li2023making,
title={Making Large Language Models A Better Foundation For Dense Retrieval},
author={Chaofan Li and Zheng Liu and Shitao Xiao and Yingxia Shao},
year={2023},
eprint={2312.15503},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
@misc{chen2024bge,
title={BGE M3-Embedding: Multi-Lingual, Multi-Functionality, Multi-Granularity Text Embeddings Through Self-Knowledge Distillation},
author={Jianlv Chen and Shitao Xiao and Peitian Zhang and Kun Luo and Defu Lian and Zheng Liu},
year={2024},
eprint={2402.03216},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
```
import os
from dataclasses import dataclass, field
from typing import Optional, List
from transformers import TrainingArguments
def default_list() -> List[str]:
return ["q_proj", "v_proj", "o_proj", "down_proj", "up_proj", "gate_proj"]
@dataclass
class ModelArguments:
"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
"""
model_name_or_path: str = field(
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
)
peft_model_path: str = field(
default=''
)
config_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
)
tokenizer_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
)
use_lora: bool = field(
default=True,
metadata={"help": "If passed, will use LORA (low-rank parameter-efficient training) to train the model."}
)
lora_rank: int = field(
default=64,
metadata={"help": "The rank of lora."}
)
lora_alpha: float = field(
default=16,
metadata={"help": "The alpha parameter of lora."}
)
lora_dropout: float = field(
default=0.1,
metadata={"help": "The dropout rate of lora modules."}
)
target_modules: List[str] = field(
default_factory=default_list
)
save_merged_lora_model: bool = field(
default=False,
metadata={"help": "If passed, will merge the lora modules and save the entire model."}
)
use_flash_attn: bool = field(
default=True,
metadata={"help": "If passed, will use flash attention to train the model."}
)
use_slow_tokenizer: bool = field(
default=False,
metadata={"help": "If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library)."}
)
low_cpu_mem_usage: bool = field(
default=False,
metadata={"help": "It is an option to create the model as an empty shell,"
"then only materialize its parameters when the pretrained weights are loaded."
"If passed, LLM loading time and RAM consumption will be benefited."}
)
cache_dir: str = field(
default="tmp", metadata={"help": "the cache of the model"}
)
token: str = field(
default=None, metadata={"help": "the token to access the huggingface model"}
)
from_peft: str = field(
default=None
)
lora_extra_parameters: str = field(
default=None
)
@dataclass
class DataArguments:
train_data: str = field(
default='toy_finetune_data.jsonl', metadata={"help": "Path to train data"}
)
train_group_size: int = field(default=8)
query_max_len: int = field(
default=32,
metadata={
"help": "The maximum total input sequence length after tokenization for passage. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
},
)
passage_max_len: int = field(
default=128,
metadata={
"help": "The maximum total input sequence length after tokenization for passage. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
},
)
max_example_num_per_dataset: int = field(
default=100000000, metadata={"help": "the max number of examples for each dataset"}
)
query_instruction_for_retrieval: str = field(
default="A: ", metadata={"help": "query: "}
)
passage_instruction_for_retrieval: str = field(
default="B: ", metadata={"help": "passage: "}
)
cache_path: str = field(
default='./data_dir'
)
load_from_disk: bool = field(
default=False, metadata={"help": "whether to load the data from disk"}
)
load_disk_path: str = field(
default=None, metadata={"help": "the path to load the data from", "nargs": "+"}
)
save_to_disk: bool = field(
default=False, metadata={"help": "whether to save the data to disk"}
)
save_disk_path: str = field(
default=None, metadata={"help": "the path to save the data to"}
)
num_shards: int = field(
default=0, metadata={
"help": "number of shards to write, prior than `save_max_shard_size`, default depends on `save_max_shard_size`"}
)
save_max_shard_size: str = field(
default="50GB", metadata={"help": "the max size of the shard"}
)
exit_after_save: bool = field(
default=False, metadata={"help": " whether exit after save the data"}
)
def __post_init__(self):
if not os.path.exists(self.train_data):
raise FileNotFoundError(f"cannot find file: {self.train_data}, please set a valid path")
@dataclass
class RetrieverTrainingArguments(TrainingArguments):
loss_type: str = field(default='only logits')
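# These dataclasses are consumed together through HfArgumentParser (see run.py), e.g.:
#   from transformers import HfArgumentParser
#   parser = HfArgumentParser((ModelArguments, DataArguments, RetrieverTrainingArguments))
#   model_args, data_args, training_args = parser.parse_args_into_dataclasses()
# so every field above becomes a --flag on the training command line.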
import re
import sys
from typing import List
import math
import os.path
import random
from dataclasses import dataclass
import datasets
import numpy as np
from torch.utils.data import Dataset
from transformers import DataCollatorForSeq2Seq
from transformers import PreTrainedTokenizer, BatchEncoding
from .arguments import DataArguments
class TrainDatasetForReranker(Dataset):
def __init__(
self,
args: DataArguments,
tokenizer: PreTrainedTokenizer
):
if os.path.isdir(args.train_data):
train_datasets = []
for file in os.listdir(args.train_data):
try:
temp_dataset = datasets.load_dataset('json', data_files=os.path.join(args.train_data, file),
split='train',
cache_dir=args.cache_path)
except Exception as e:
print(f"failed to load {file}: {e}")
sys.exit(1)
if len(temp_dataset) > args.max_example_num_per_dataset:
temp_dataset = temp_dataset.select(
random.sample(list(range(len(temp_dataset))), args.max_example_num_per_dataset))
train_datasets.append(temp_dataset)
self.dataset = datasets.concatenate_datasets(train_datasets)
else:
self.dataset = datasets.load_dataset('json', data_files=args.train_data, split='train', cache_dir=args.cache_path)
self.tokenizer = tokenizer
self.args = args
self.total_len = len(self.dataset)
sep = "\n"
self.sep_inputs = self.tokenizer(sep,
return_tensors=None,
add_special_tokens=False)['input_ids']
self.max_length = self.args.query_max_len + self.args.passage_max_len
def __len__(self):
return self.total_len
def is_chinese(self, text):
chinese_pattern = re.compile('[\u4e00-\u9fa5]')
return bool(chinese_pattern.search(text))
def __getitem__(self, item) -> List[BatchEncoding]:
query = self.dataset[item]['query']
passages = []
pos = random.choice(self.dataset[item]['pos'])
passages.append(pos)
if len(self.dataset[item]['neg']) < self.args.train_group_size - 1:
# not enough negatives: oversample the list until train_group_size - 1 can be drawn
num = math.ceil((self.args.train_group_size - 1) / len(self.dataset[item]['neg']))
negs = random.sample(self.dataset[item]['neg'] * num, self.args.train_group_size - 1)
else:
negs = random.sample(self.dataset[item]['neg'], self.args.train_group_size - 1)
passages.extend(negs)
prompt = self.dataset[item]['prompt']
query = f'{self.args.query_instruction_for_retrieval}{query}'
passages = [f'{self.args.passage_instruction_for_retrieval}{p}' for p in passages]
query_inputs = self.tokenizer(query,
return_tensors=None,
max_length=self.args.query_max_len + self.args.passage_max_len // 4,
truncation=True,
add_special_tokens=False)
# token ids of the prompt followed by the target token "Yes"; the model is
# trained to generate "Yes" after each (query, passage, prompt) sequence
positive_inputs = self.tokenizer(prompt,
return_tensors=None,
add_special_tokens=False)['input_ids'] + \
self.tokenizer('Yes',
return_tensors=None,
add_special_tokens=False)['input_ids']
max_length = self.max_length - len(positive_inputs) - len(self.sep_inputs)
passages_inputs = []
for i, passage in enumerate(passages):
passage_inputs = self.tokenizer(passage,
return_tensors=None,
max_length=self.args.passage_max_len + self.args.query_max_len // 2,
truncation=True,
add_special_tokens=False)
# prepend BOS when the tokenizer has one that is distinct from PAD; also avoid
# shadowing the dataset index `item` with the encoded pair
if self.tokenizer.bos_token_id is not None and self.tokenizer.bos_token_id != self.tokenizer.pad_token_id:
first_ids = [self.tokenizer.bos_token_id] + query_inputs['input_ids']
else:
first_ids = query_inputs['input_ids']
encoded = self.tokenizer.prepare_for_model(
first_ids,
self.sep_inputs + passage_inputs['input_ids'],
truncation='only_second',
max_length=max_length,
padding=False,
return_attention_mask=False,
return_token_type_ids=False,
add_special_tokens=False
)
passage_inputs['input_ids'] = encoded['input_ids'] + self.sep_inputs + positive_inputs
passage_inputs['attention_mask'] = [1] * len(passage_inputs['input_ids'])
passage_inputs['labels'] = passage_inputs['input_ids'].copy()
# mask every position except the last one (the "Yes" target token)
passage_inputs['labels'] = [-100] * (len(passage_inputs['input_ids']) - 1) + passage_inputs['labels'][(len(passage_inputs['input_ids']) - 1):]
passage_inputs.pop('token_type_ids', None)
if 'position_ids' in passage_inputs.keys():
passage_inputs['position_ids'] = list(range(len(passage_inputs['input_ids'])))
passages_inputs.append(passage_inputs)
return passages_inputs
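# Each training record is a JSON line with the fields read above, e.g.:
#   {"query": "...", "pos": ["relevant passage"], "neg": ["hard negative", ...],
#    "prompt": "Predict whether the passage answers the query."}
# (the prompt wording is illustrative; only the field names are fixed by this code)
# __getitem__ returns train_group_size encodings: the positive first, then negatives.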
@dataclass
class RerankCollator(DataCollatorForSeq2Seq):
"""
Wrapper that does conversion from List[Tuple[encode_qry, encode_psg]] to List[qry], List[psg]
and pass batch separately to the actual collator.
Abstract out data detail for the model.
"""
query_max_len: int = 32
passage_max_len: int = 128
def __call__(self, features, return_tensors='pt'):
if return_tensors is None:
return_tensors = self.return_tensors
if isinstance(features[0], list):
features = sum(features, [])
labels = [feature["labels"] for feature in features] if "labels" in features[0].keys() else None
# We have to pad the labels before calling `tokenizer.pad` as this method won't pad them and needs them of the
# same length to return tensors.
if labels is not None:
max_label_length = max(len(l) for l in labels)
if self.pad_to_multiple_of is not None:
max_label_length = (
(max_label_length + self.pad_to_multiple_of - 1)
// self.pad_to_multiple_of
* self.pad_to_multiple_of
)
padding_side = self.tokenizer.padding_side
for feature in features:
remainder = [self.label_pad_token_id] * (max_label_length - len(feature["labels"]))
if isinstance(feature["labels"], list):
feature["labels"] = (
feature["labels"] + remainder if padding_side == "right" else remainder + feature["labels"]
)
elif padding_side == "right":
feature["labels"] = np.concatenate([feature["labels"], remainder]).astype(np.int64)
else:
feature["labels"] = np.concatenate([remainder, feature["labels"]]).astype(np.int64)
collated = self.tokenizer.pad(
features,
padding=self.padding,
max_length=self.query_max_len + self.passage_max_len,
return_tensors=return_tensors,
pad_to_multiple_of=self.pad_to_multiple_of,
)
return {"pair": collated}
import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, TaskType, get_peft_model, PeftModel
def get_model(model_args, training_args):
model = AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path,
torch_dtype=torch.float16 if training_args.fp16 else torch.bfloat16,
use_flash_attention_2=model_args.use_flash_attn,
token=model_args.token,
cache_dir=model_args.cache_dir,
from_tf=bool(".ckpt" in model_args.model_name_or_path),
trust_remote_code=True,
)
model.config.use_cache = False
if model_args.from_peft is not None:
model = PeftModel.from_pretrained(model, model_args.from_peft, is_trainable=True)
model.print_trainable_parameters()
else:
if model_args.use_lora:
peft_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
inference_mode=False,
r=model_args.lora_rank,
target_modules=model_args.target_modules,
lora_alpha=model_args.lora_alpha,
lora_dropout=model_args.lora_dropout,
modules_to_save=model_args.lora_extra_parameters
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
print(model)
return model
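# Minimal sketch of using get_model outside the training script (the model name
# and flags are illustrative assumptions):
#   from transformers import HfArgumentParser, TrainingArguments
#   parser = HfArgumentParser((ModelArguments, TrainingArguments))
#   model_args, training_args = parser.parse_args_into_dataclasses(
#       ["--model_name_or_path", "meta-llama/Llama-2-7b-hf",
#        "--output_dir", "./out", "--bf16", "True"])
#   model = get_model(model_args, training_args)  # LoRA-wrapped causal LM in bf16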
import logging
from dataclasses import dataclass
from typing import Dict, Optional, List, Union
import torch
from torch import nn, Tensor
from transformers import AutoTokenizer
from transformers.file_utils import ModelOutput
logger = logging.getLogger(__name__)
@dataclass
class RerankerOutput(ModelOutput):
loss: Optional[Tensor] = None
scores: Optional[Tensor] = None
class BiEncoderModel(nn.Module):
# despite the name, this wraps a causal LM as a pointwise (cross-encoder) reranker
def __init__(self,
model: nn.Module = None,
tokenizer: AutoTokenizer = None,
train_batch_size: int = 4,
):
super().__init__()
self.model = model
self.tokenizer = tokenizer
self.cross_entropy = nn.CrossEntropyLoss(reduction='mean')
if self.model.config.pad_token_id is None:
self.model.config.pad_token_id = self.tokenizer.pad_token_id
self.config = self.model.config
self.train_batch_size = train_batch_size
# token id of "Yes"; its logit at the prediction position is the relevance score
self.yes_loc = self.tokenizer('Yes', add_special_tokens=False)['input_ids'][-1]
def gradient_checkpointing_enable(self, **kwargs):
self.model.gradient_checkpointing_enable(**kwargs)
def enable_input_require_grads(self, **kwargs):
self.model.enable_input_require_grads(**kwargs)
def encode(self, features):
if features is None:
return None
outputs = self.model(input_ids=features['input_ids'],
attention_mask=features['attention_mask'],
position_ids=features['position_ids'] if 'position_ids' in features.keys() else None,
output_hidden_states=True)
# labels are -100 everywhere except the final "Yes" position, so the row-wise
# argmax locates that position; the logits one step earlier are what predict it
_, max_indices = torch.max(features['labels'], dim=1)
predict_indices = max_indices - 1
logits = [outputs.logits[i, predict_indices[i], :] for i in range(outputs.logits.shape[0])]
logits = torch.stack(logits, dim=0)
# relevance score = logit assigned to the "Yes" token
scores = logits[:, self.yes_loc]
return scores.contiguous()
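# Worked example (token ids are made up): for labels [-100, -100, -100, 9891],
# the max sits at index 3 (the "Yes" id), so predict_indices = 2 and the score
# is logits[i, 2, yes_loc] -- the logit on "Yes" right before it must be generated.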
def forward(self, pair: Union[Dict[str, Tensor], List[Dict[str, Tensor]]] = None):
ranker_logits = self.encode(pair) # (batch_size * train_group_size,)
if self.training:
grouped_logits = ranker_logits.view(self.train_batch_size, -1)
target = torch.zeros(self.train_batch_size, device=grouped_logits.device, dtype=torch.long)
loss = self.compute_loss(grouped_logits, target)
else:
loss = None
return RerankerOutput(
loss=loss,
scores=ranker_logits,
)
def compute_loss(self, scores, target):
return self.cross_entropy(scores, target)
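# Example: with train_batch_size=2 and train_group_size=4, encode() returns 8
# scores; view(2, -1) groups them as [[pos, neg, neg, neg], ...], and the
# all-zero target pushes each positive's score above its own negatives.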
def save(self, output_dir: str):
# move weights to CPU before saving to avoid GPU-resident checkpoints
state_dict = self.model.state_dict()
state_dict = type(state_dict)(
{k: v.clone().cpu() for k, v in state_dict.items()})
self.model.save_pretrained(output_dir, state_dict=state_dict)
def save_pretrained(self, **kwargs):
self.tokenizer.save_pretrained(**kwargs)
return self.model.save_pretrained(**kwargs)
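# Minimal inference sketch (names are illustrative; `batch` must be built the
# same way as in TrainDatasetForReranker + RerankCollator, ending at the "Yes" slot):
#   model = BiEncoderModel(model=base_model, tokenizer=tokenizer)
#   model.eval()
#   with torch.no_grad():
#       scores = model(pair=batch["pair"]).scores  # higher = more relevant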
import logging
import os
from pathlib import Path
from transformers import AutoConfig, AutoTokenizer
from transformers import (
HfArgumentParser,
set_seed,
)
from .arguments import ModelArguments, DataArguments, \
RetrieverTrainingArguments as TrainingArguments
from .data import TrainDatasetForReranker, RerankCollator
from .modeling import BiEncoderModel
from .trainer import BiTrainer
from .load_model import get_model
logger = logging.getLogger(__name__)
def main():
parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
model_args: ModelArguments
data_args: DataArguments
training_args: TrainingArguments
if (
os.path.exists(training_args.output_dir)
and os.listdir(training_args.output_dir)
and training_args.do_train
and not training_args.overwrite_output_dir
):
raise ValueError(
f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
)
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
)
logger.warning(
"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
training_args.local_rank,
training_args.device,
training_args.n_gpu,
bool(training_args.local_rank != -1),
training_args.fp16,
)
logger.info("Training/evaluation parameters %s", training_args)
logger.info("Model parameters %s", model_args)
logger.info("Data parameters %s", data_args)
# Set seed
set_seed(training_args.seed)
num_labels = 1
base_model = get_model(model_args, training_args)
tokenizer = AutoTokenizer.from_pretrained(
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
use_fast=False,
trust_remote_code=True,
token=model_args.token,
add_eos_token=True
)
if tokenizer.pad_token_id is None:
if tokenizer.unk_token_id is not None:
tokenizer.pad_token_id = tokenizer.unk_token_id
elif getattr(tokenizer, 'eod_id', None) is not None:
# Qwen-style tokenizers expose eod_id / im_start_id / im_end_id
tokenizer.pad_token_id = tokenizer.eod_id
tokenizer.bos_token_id = tokenizer.im_start_id
tokenizer.eos_token_id = tokenizer.im_end_id
if 'mistral' in model_args.model_name_or_path.lower():
tokenizer.padding_side = 'left'
config = AutoConfig.from_pretrained(
model_args.config_name if model_args.config_name else model_args.model_name_or_path,
num_labels=num_labels,
cache_dir=model_args.cache_dir,
trust_remote_code=True,
)
logger.info('Config: %s', config)
model = BiEncoderModel(model=base_model,
tokenizer=tokenizer,
train_batch_size=training_args.per_device_train_batch_size)
if training_args.gradient_checkpointing:
model.enable_input_require_grads()
train_dataset = TrainDatasetForReranker(args=data_args, tokenizer=tokenizer)
trainer = BiTrainer(
model=model,
args=training_args,
train_dataset=train_dataset,
data_collator=RerankCollator(
tokenizer=tokenizer,
query_max_len=data_args.query_max_len,
passage_max_len=data_args.passage_max_len,
pad_to_multiple_of=8,
return_tensors="pt",
padding=True
),
tokenizer=tokenizer,
)
trainer.use_lora = model_args.use_lora
Path(training_args.output_dir).mkdir(parents=True, exist_ok=True)
# Training
trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
trainer.save_model()
if not model_args.use_lora and getattr(trainer, "deepspeed", None) is not None:
checkpoint_dir = os.path.join(training_args.output_dir, "checkpoint-final")
trainer.deepspeed.save_checkpoint(checkpoint_dir)
# For convenience, we also re-save the tokenizer to the same directory,
# so that you can share your model easily on huggingface.co/models =)
if trainer.is_world_process_zero():
tokenizer.save_pretrained(training_args.output_dir)
if __name__ == "__main__":
main()
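# A hypothetical multi-GPU launch (the module path and model name are
# illustrative, not fixed by this repo):
#   torchrun --nproc_per_node 8 -m reranker.run \
#       --model_name_or_path meta-llama/Llama-2-7b-hf \
#       --train_data toy_finetune_data.jsonl \
#       --output_dir ./output --do_train \
#       --per_device_train_batch_size 4 --gradient_checkpointing \
#       --fp16 --use_lora True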
# the star import intentionally provides Trainer plus the os/torch/Optional/logger names used below
from transformers.trainer import *
from transformers.deepspeed import is_deepspeed_zero3_enabled
from peft import get_peft_model_state_dict
class BiTrainer(Trainer):
use_lora: bool # set by the caller (see run.py) before training starts
def _save(self, output_dir: Optional[str] = None, state_dict=None):
if not self.use_lora:
super()._save(output_dir, state_dict)
return
output_dir = output_dir if output_dir is not None else self.args.output_dir
os.makedirs(output_dir, exist_ok=True)
logger.info("Saving model checkpoint to %s", output_dir)
# Save a trained model and configuration using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()`
if not hasattr(self.model, 'save'):
raise NotImplementedError(
f'MODEL {self.model.__class__.__name__} '
f'does not support save interface')
else:
self.model.save(output_dir)
torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
if is_deepspeed_zero3_enabled():
if state_dict is None:
state_dict = self.model.state_dict()
# strip the wrapper's 'model.' prefix so keys match the underlying LM
prefix = 'model.'
assert all(k.startswith(prefix) for k in state_dict.keys()), list(state_dict.keys())
state_dict = {k[len(prefix):]: v for k, v in state_dict.items()}
lora_state_dict = get_peft_model_state_dict(self.model.model, state_dict)
if self.args.process_index <= 0:
torch.save(lora_state_dict, os.path.join(output_dir, "adapter_model.bin"))
print(f"Saved adapter model to {output_dir}")
def compute_loss(self, model, inputs, return_outputs=False):
"""
How the loss is computed by Trainer. By default, all models return the loss in the first element.
Subclass and override for custom behavior.
"""
outputs = model(**inputs)
loss = outputs.loss
return (loss, outputs) if return_outputs else loss
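# Loading the saved LoRA adapter back for inference (a sketch; the base model
# and output path are illustrative, and assume the non-zero3 save path wrote a
# full adapter directory via save_pretrained):
#   from transformers import AutoModelForCausalLM
#   from peft import PeftModel
#   base = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
#   reranker_lm = PeftModel.from_pretrained(base, "./output")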