import logging
from dataclasses import dataclass
from typing import Dict, Optional
import os
import torch
import torch.distributed as dist
from torch import nn, Tensor
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer
from transformers.file_utils import ModelOutput
from huggingface_hub import snapshot_download
logger = logging.getLogger(__name__)
@dataclass
class EncoderOutput(ModelOutput):
q_reps: Optional[Tensor] = None
p_reps: Optional[Tensor] = None
loss: Optional[Tensor] = None
scores: Optional[Tensor] = None
class BGEM3Model(nn.Module):
def __init__(self,
model_name: str = None,
normlized: bool = True,
sentence_pooling_method: str = 'cls',
negatives_cross_device: bool = False,
temperature: float = 1.0,
enable_sub_batch: bool = True,
unified_finetuning: bool = True,
use_self_distill: bool = False,
colbert_dim: int = -1,
self_distill_start_step: int = -1,
):
super().__init__()
self.load_model(model_name, colbert_dim=colbert_dim)
self.vocab_size = self.model.config.vocab_size
self.cross_entropy = nn.CrossEntropyLoss(reduction='mean')
self.unified_finetuning = unified_finetuning
if not self.unified_finetuning:
self.colbert_linear = None
self.sparse_linear = None
self.normlized = normlized
self.sentence_pooling_method = sentence_pooling_method
self.enable_sub_batch = enable_sub_batch
self.temperature = temperature
self.use_self_distill = use_self_distill
self.self_distill_start_step = self_distill_start_step
self.step = 0
if not normlized:
self.temperature = 1.0
logger.info("reset temperature = 1.0 due to using inner product to compute similarity")
self.negatives_cross_device = negatives_cross_device
if self.negatives_cross_device:
if not dist.is_initialized():
raise ValueError('Distributed training has not been initialized for representation all gather.')
self.process_rank = dist.get_rank()
self.world_size = dist.get_world_size()
def load_model(self, model_name, colbert_dim: int = -1):
if not os.path.exists(model_name):
cache_folder = os.getenv('HF_HUB_CACHE')
model_name = snapshot_download(repo_id=model_name,
cache_dir=cache_folder,
ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5'])
self.model = AutoModel.from_pretrained(model_name)
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.colbert_linear = torch.nn.Linear(in_features=self.model.config.hidden_size,
out_features=self.model.config.hidden_size if colbert_dim == -1 else colbert_dim)
self.sparse_linear = torch.nn.Linear(in_features=self.model.config.hidden_size, out_features=1)
if os.path.exists(os.path.join(model_name, 'colbert_linear.pt')) and os.path.exists(
os.path.join(model_name, 'sparse_linear.pt')):
logger.info('loading existing colbert_linear and sparse_linear---------')
self.load_pooler(model_dir=model_name)
else:
            logger.info(
                'The parameters of colbert_linear and sparse_linear are newly initialized. Make sure the model is loaded for training, not inference.')
def gradient_checkpointing_enable(self, **kwargs):
self.model.gradient_checkpointing_enable(**kwargs)
def dense_embedding(self, hidden_state, mask):
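        # 'cls' pooling takes the [CLS] token's hidden state; 'mean' averages token states
        # weighted by the attention mask so padding does not contribute.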
if self.sentence_pooling_method == 'cls':
return hidden_state[:, 0]
elif self.sentence_pooling_method == 'mean':
s = torch.sum(hidden_state * mask.unsqueeze(-1).float(), dim=1)
            d = mask.sum(dim=1, keepdim=True).float()
return s / d
def sparse_embedding(self, hidden_state, input_ids, return_embedding: bool = True):
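        # Per-token term weights (ReLU keeps them non-negative). When an embedding is requested,
        # the weights are scattered into a (batch, seq_len, vocab_size) tensor and max-pooled over
        # the sequence, giving one weight per vocabulary term; special tokens are zeroed below.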
token_weights = torch.relu(self.sparse_linear(hidden_state))
if not return_embedding: return token_weights
sparse_embedding = torch.zeros(input_ids.size(0), input_ids.size(1), self.vocab_size,
dtype=token_weights.dtype,
device=token_weights.device)
sparse_embedding = torch.scatter(sparse_embedding, dim=-1, index=input_ids.unsqueeze(-1), src=token_weights)
unused_tokens = [self.tokenizer.cls_token_id, self.tokenizer.eos_token_id, self.tokenizer.pad_token_id,
self.tokenizer.unk_token_id]
sparse_embedding = torch.max(sparse_embedding, dim=1).values
sparse_embedding[:, unused_tokens] *= 0.
return sparse_embedding
def colbert_embedding(self, last_hidden_state, mask):
colbert_vecs = self.colbert_linear(last_hidden_state[:, 1:])
colbert_vecs = colbert_vecs * mask[:, 1:][:, :, None].float()
return colbert_vecs
def dense_score(self, q_reps, p_reps):
scores = self.compute_similarity(q_reps, p_reps) / self.temperature
scores = scores.view(q_reps.size(0), -1)
return scores
def sparse_score(self, q_reps, p_reps):
scores = self.compute_similarity(q_reps, p_reps) / self.temperature
scores = scores.view(q_reps.size(0), -1)
return scores
def colbert_score(self, q_reps, p_reps, q_mask: torch.Tensor):
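        # Late interaction: token_scores[q, i, p, j] = <q_i, p_j>. Take the max over passage
        # tokens (MaxSim), then average over the query's valid (non-padding) tokens.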
token_scores = torch.einsum('qin,pjn->qipj', q_reps, p_reps)
scores, _ = token_scores.max(-1)
scores = scores.sum(1) / q_mask[:, 1:].sum(-1, keepdim=True)
scores = scores / self.temperature
return scores
def _encode(self, features):
dense_vecs, sparse_vecs, colbert_vecs = None, None, None
last_hidden_state = self.model(**features, return_dict=True).last_hidden_state
dense_vecs = self.dense_embedding(last_hidden_state, features['attention_mask'])
if self.unified_finetuning:
sparse_vecs = self.sparse_embedding(last_hidden_state, features['input_ids'])
colbert_vecs = self.colbert_embedding(last_hidden_state, features['attention_mask'])
if self.normlized:
dense_vecs = torch.nn.functional.normalize(dense_vecs, dim=-1)
if self.unified_finetuning:
colbert_vecs = torch.nn.functional.normalize(colbert_vecs, dim=-1)
return dense_vecs, sparse_vecs, colbert_vecs
def encode(self, features, sub_batch_size=None):
if features is None:
return None
if sub_batch_size is not None and sub_batch_size != -1:
all_dense_vecs, all_sparse_vecs, all_colbert_vecs = [], [], []
for i in range(0, len(features['attention_mask']), sub_batch_size):
end_inx = min(i + sub_batch_size, len(features['attention_mask']))
sub_features = {}
for k, v in features.items():
sub_features[k] = v[i:end_inx]
dense_vecs, sparse_vecs, colbert_vecs = self._encode(sub_features)
all_dense_vecs.append(dense_vecs)
all_sparse_vecs.append(sparse_vecs)
all_colbert_vecs.append(colbert_vecs)
dense_vecs = torch.cat(all_dense_vecs, 0)
if self.unified_finetuning:
sparse_vecs = torch.cat(all_sparse_vecs, 0)
colbert_vecs = torch.cat(all_colbert_vecs, 0)
else:
dense_vecs, sparse_vecs, colbert_vecs = self._encode(features)
if self.unified_finetuning:
return dense_vecs.contiguous(), sparse_vecs.contiguous(), colbert_vecs.contiguous()
else:
return dense_vecs.contiguous(), None, None
def compute_sub_batch_size(self, features):
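        # Heuristic: the longer the padded sequence, the smaller the sub-batch, bounding peak memory.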
mapping = [(6000, 1), (5000, 2), (4000, 3), (3000, 3), (2000, 5), (1000, 9), (512, 16), (0, 32)]
cur_l = features['input_ids'].size(-1)
for l, b in mapping:
if cur_l >= l:
return b
def compute_similarity(self, q_reps, p_reps):
if len(p_reps.size()) == 2:
return torch.matmul(q_reps, p_reps.transpose(0, 1))
return torch.matmul(q_reps, p_reps.transpose(-2, -1))
def distill_loss(self, teacher_targets, student_scores, group_size):
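        # teacher_targets: (batch, group_size) soft labels over each query's passage group.
        # At step i, compute cross-entropy against position (label + i) for every query, weight it
        # by the teacher's probability for that position, then mask the position out so later
        # steps cannot re-select it.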
labels = torch.arange(student_scores.size(0), device=student_scores.device, dtype=torch.long)
labels = labels * group_size
loss = 0
mask = torch.zeros_like(student_scores)
for i in range(group_size):
temp_target = labels + i
temp_scores = student_scores + mask
temp_loss = F.cross_entropy(temp_scores, temp_target, reduction="none") # B
loss += torch.mean(teacher_targets[:, i] * temp_loss)
mask = torch.scatter(mask, dim=-1, index=temp_target.unsqueeze(-1),
value=torch.finfo(student_scores.dtype).min)
return loss
def forward(self, query: Dict[str, Tensor] = None, passage: Dict[str, Tensor] = None, teacher_scores: Tensor = None,
bi_directions=None):
if self.enable_sub_batch:
q_dense_vecs, q_sparse_vecs, q_colbert_vecs = self.encode(query,
sub_batch_size=self.compute_sub_batch_size(query))
p_dense_vecs, p_sparse_vecs, p_colbert_vecs = self.encode(passage,
sub_batch_size=self.compute_sub_batch_size(
passage))
else:
q_dense_vecs, q_sparse_vecs, q_colbert_vecs = self.encode(query)
p_dense_vecs, p_sparse_vecs, p_colbert_vecs = self.encode(passage)
if self.training:
if teacher_scores is not None:
# print("Use soft-label distillation...")
teacher_targets = F.softmax(teacher_scores, dim=-1) # B N
                group_size = p_dense_vecs.size(0) // q_dense_vecs.size(0)  # dense vecs exist in all modes (sparse may be None)
# dense loss
dense_scores = self.dense_score(q_dense_vecs, p_dense_vecs) # B, B * N
if self.negatives_cross_device:
cross_q_dense_vecs = self._dist_gather_tensor(q_dense_vecs)
cross_p_dense_vecs = self._dist_gather_tensor(p_dense_vecs)
cross_teacher_targets = self._dist_gather_tensor(teacher_targets)
cross_dense_scores = self.dense_score(cross_q_dense_vecs, cross_p_dense_vecs)
loss = self.distill_loss(cross_teacher_targets, cross_dense_scores, group_size=group_size)
else:
loss = self.distill_loss(teacher_targets, dense_scores, group_size=group_size)
if self.unified_finetuning:
# sparse and colbert loss
sparse_scores = self.sparse_score(q_sparse_vecs, p_sparse_vecs) # B, B * N
sparse_loss = self.distill_loss(teacher_targets, sparse_scores, group_size=group_size)
colbert_scores = self.colbert_score(q_colbert_vecs, p_colbert_vecs,
q_mask=query['attention_mask']) # B, B * N
colbert_loss = self.distill_loss(teacher_targets, colbert_scores, group_size=group_size)
ensemble_loss = self.distill_loss(teacher_targets,
dense_scores + 0.3 * sparse_scores + colbert_scores,
group_size=group_size)
loss = (loss + ensemble_loss + 0.1 * sparse_loss + colbert_loss) / 4
else:
idxs = torch.arange(q_dense_vecs.size(0), device=q_dense_vecs.device, dtype=torch.long)
targets = idxs * (p_dense_vecs.size(0) // q_dense_vecs.size(0))
# dense loss
dense_scores = self.dense_score(q_dense_vecs, p_dense_vecs) # B, B * N
if self.negatives_cross_device:
cross_q_dense_vecs = self._dist_gather_tensor(q_dense_vecs)
cross_p_dense_vecs = self._dist_gather_tensor(p_dense_vecs)
cross_idxs = torch.arange(cross_q_dense_vecs.size(0), device=cross_q_dense_vecs.device, dtype=torch.long)
cross_targets = cross_idxs * (cross_p_dense_vecs.size(0) // cross_q_dense_vecs.size(0))
cross_dense_scores = self.dense_score(cross_q_dense_vecs, cross_p_dense_vecs)
loss = self.compute_loss(cross_dense_scores, cross_targets)
else:
loss = self.compute_loss(dense_scores, targets)
if self.unified_finetuning:
# sparse and colbert loss
sparse_scores = self.sparse_score(q_sparse_vecs, p_sparse_vecs) # B, B * N
sparse_loss = self.compute_loss(sparse_scores, targets)
colbert_scores = self.colbert_score(q_colbert_vecs, p_colbert_vecs,
q_mask=query['attention_mask']) # B, B * N
colbert_loss = self.compute_loss(colbert_scores, targets)
ensemble_loss = self.compute_loss(dense_scores + 0.3 * sparse_scores + colbert_scores, targets)
loss = (loss + ensemble_loss + 0.1 * sparse_loss + colbert_loss) / 4
if self.use_self_distill and self.step > self.self_distill_start_step and self.unified_finetuning:
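                # Self-distillation: the detached ensemble distribution teaches each individual head.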
ensemble_scores = dense_scores + 0.3 * sparse_scores + colbert_scores
teacher_targets = torch.softmax(ensemble_scores.detach(), dim=-1)
ensemble_distill_dense_loss = - torch.mean(
torch.sum(torch.log_softmax(dense_scores, dim=-1) * teacher_targets, dim=-1))
ensemble_distill_sparse_loss = - torch.mean(
torch.sum(torch.log_softmax(sparse_scores, dim=-1) * teacher_targets, dim=-1))
ensemble_distill_colbert_loss = - torch.mean(
torch.sum(torch.log_softmax(colbert_scores, dim=-1) * teacher_targets, dim=-1))
loss += (ensemble_distill_dense_loss + 0.1 * ensemble_distill_sparse_loss + ensemble_distill_colbert_loss) / 3
loss = loss / 2
self.step += 1
else:
loss = None
return EncoderOutput(
loss=loss,
)
def compute_loss(self, scores, target):
return self.cross_entropy(scores, target)
def _dist_gather_tensor(self, t: Optional[torch.Tensor]):
if t is None:
return None
t = t.contiguous()
all_tensors = [torch.empty_like(t) for _ in range(self.world_size)]
dist.all_gather(all_tensors, t)
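        # all_gather returns tensors detached from the autograd graph; re-insert the local
        # tensor so gradients still flow through this rank's slice.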
all_tensors[self.process_rank] = t
all_tensors = torch.cat(all_tensors, dim=0)
return all_tensors
def save(self, output_dir: str):
def _trans_state_dict(state_dict):
state_dict = type(state_dict)(
{k: v.clone().cpu()
for k,
v in state_dict.items()})
return state_dict
self.model.save_pretrained(output_dir, state_dict=_trans_state_dict(self.model.state_dict()))
if self.unified_finetuning:
torch.save(_trans_state_dict(self.colbert_linear.state_dict()),
os.path.join(output_dir, 'colbert_linear.pt'))
torch.save(_trans_state_dict(self.sparse_linear.state_dict()),
os.path.join(output_dir, 'sparse_linear.pt'))
def load_pooler(self, model_dir):
colbert_state_dict = torch.load(os.path.join(model_dir, 'colbert_linear.pt'), map_location='cpu')
sparse_state_dict = torch.load(os.path.join(model_dir, 'sparse_linear.pt'), map_location='cpu')
self.colbert_linear.load_state_dict(colbert_state_dict)
self.sparse_linear.load_state_dict(sparse_state_dict)
class BGEM3ForInference(BGEM3Model):
def forward(self,
text_input: Dict[str, Tensor] = None,
return_dense: bool = True,
return_sparse: bool = False,
return_colbert: bool = False,
return_sparse_embedding: bool = False):
        assert return_dense or return_sparse or return_colbert, 'At least one of `return_dense`, `return_sparse`, `return_colbert` must be set to `True`!'
last_hidden_state = self.model(**text_input, return_dict=True).last_hidden_state
output = {}
if return_dense:
dense_vecs = self.dense_embedding(last_hidden_state, text_input['attention_mask'])
output['dense_vecs'] = dense_vecs
if return_sparse:
sparse_vecs = self.sparse_embedding(last_hidden_state, text_input['input_ids'],
return_embedding=return_sparse_embedding)
output['sparse_vecs'] = sparse_vecs
if return_colbert:
colbert_vecs = self.colbert_embedding(last_hidden_state, text_input['attention_mask'])
output['colbert_vecs'] = colbert_vecs
if self.normlized:
if 'dense_vecs' in output:
output['dense_vecs'] = torch.nn.functional.normalize(output['dense_vecs'], dim=-1)
if 'colbert_vecs' in output:
output['colbert_vecs'] = torch.nn.functional.normalize(output['colbert_vecs'], dim=-1)
return output
import logging
import os
from pathlib import Path
import torch.distributed as dist
from transformers import AutoConfig, AutoTokenizer
from transformers import (
HfArgumentParser,
set_seed,
)
from transformers import (
    TrainerCallback,
    TrainerState,
    TrainerControl
)
from .arguments import ModelArguments, DataArguments, \
RetrieverTrainingArguments as TrainingArguments
from .data import SameDatasetTrainDataset, EmbedCollator
from .modeling import BGEM3Model
from .trainer import BiTrainer
logger = logging.getLogger(__name__)
class TrainerCallbackForDataRefresh(TrainerCallback):
def __init__(self, train_dataset):
self.train_dataset = train_dataset
def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
"""
Event called at the end of an epoch.
"""
self.train_dataset.refresh_epoch()
def main():
parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
model_args: ModelArguments
data_args: DataArguments
training_args: TrainingArguments
if (
os.path.exists(training_args.output_dir)
and os.listdir(training_args.output_dir)
and training_args.do_train
and not training_args.overwrite_output_dir
):
raise ValueError(
f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
)
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
)
logger.warning(
"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
training_args.local_rank,
training_args.device,
training_args.n_gpu,
bool(training_args.local_rank != -1),
training_args.fp16,
)
logger.info("Training/evaluation parameters %s", training_args)
logger.info("Model parameters %s", model_args)
logger.info("Data parameters %s", data_args)
# Set seed
set_seed(training_args.seed)
num_labels = 1
tokenizer = AutoTokenizer.from_pretrained(
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
use_fast=False,
)
config = AutoConfig.from_pretrained(
model_args.config_name if model_args.config_name else model_args.model_name_or_path,
num_labels=num_labels,
cache_dir=model_args.cache_dir,
)
logger.info('Config: %s', config)
model = BGEM3Model(model_name=model_args.model_name_or_path,
normlized=training_args.normlized,
sentence_pooling_method=training_args.sentence_pooling_method,
negatives_cross_device=training_args.negatives_cross_device,
temperature=training_args.temperature,
enable_sub_batch=training_args.enable_sub_batch,
unified_finetuning=training_args.unified_finetuning,
use_self_distill=training_args.use_self_distill,
colbert_dim=training_args.colbert_dim,
self_distill_start_step=training_args.self_distill_start_step)
if training_args.fix_position_embedding:
for k, v in model.named_parameters():
if "position_embeddings" in k:
logging.info(f"Freeze the parameters for {k}")
v.requires_grad = False
if training_args.fix_encoder:
for k, v in model.named_parameters():
if "colbert_linear" in k or 'sparse_linear' in k:
logging.info(f"train the parameters for {k}")
else:
v.requires_grad = False
# print(f"===========================Rank {dist.get_rank()}: start loading data===========================")
if data_args.same_task_within_batch:
train_dataset = SameDatasetTrainDataset(args=data_args,
batch_size=training_args.per_device_train_batch_size,
seed=training_args.seed,
num_processes=training_args.world_size,
process_index=training_args.process_index)
training_args.per_device_train_batch_size = 1
training_args.dataloader_num_workers = 0 # avoid multi-processes
else:
        raise NotImplementedError("`same_task_within_batch=False` is not supported")
data_collator = EmbedCollator(
tokenizer,
query_max_len=data_args.query_max_len,
passage_max_len=data_args.passage_max_len
)
trainer = BiTrainer(
model=model,
args=training_args,
train_dataset=train_dataset,
data_collator=data_collator,
tokenizer=tokenizer
)
if data_args.same_task_within_batch:
trainer.add_callback(TrainerCallbackForDataRefresh(train_dataset))
Path(training_args.output_dir).mkdir(parents=True, exist_ok=True)
# Training
# print(f"===========================Rank {dist.get_rank()}: start training===========================")
trainer.train()
trainer.save_model()
# For convenience, we also re-save the tokenizer to the same directory,
# so that you can share your model easily on huggingface.co/models =)
if trainer.is_world_process_zero():
tokenizer.save_pretrained(training_args.output_dir)
if __name__ == "__main__":
main()
"""
python split_data_by_length.py \
--input_path train_data \
--output_dir train_data_split \
--cache_dir .cache \
--log_name .split_log \
--length_list 0 500 1000 2000 3000 4000 5000 6000 7000 \
--model_name_or_path BAAI/bge-m3 \
--num_proc 16
"""
import os
import json
import math
import time
import argparse
import datasets
from tqdm import tqdm
from pprint import pprint
from transformers import AutoTokenizer
from datasets import load_dataset, Features, Value, Sequence
def get_args():
parser = argparse.ArgumentParser()
    parser.add_argument('--input_path', type=str, required=True, help='the path of the input data')
    parser.add_argument('--output_dir', type=str, required=True, help='the directory for the output data')
    parser.add_argument('--cache_dir', type=str, default=None, help='the cache dir')
    parser.add_argument('--log_name', type=str, default='.split_log', help='the name of the log file, default: `.split_log`; it will be saved to `output_dir`')
parser.add_argument('--length_list', type=int, default=[0, 500, 1000, 2000, 3000, 4000, 5000, 6000, 7000], nargs='+', help='the length list to split')
parser.add_argument('--model_name_or_path', type=str, default='BAAI/bge-m3', help='the model name or path of the tokenizer')
    parser.add_argument('--num_proc', type=int, default=16, help='the number of processes, default: 16')
    parser.add_argument('--overwrite', action='store_true', default=False, help='whether to overwrite existing output files, default: False')
args = parser.parse_args()
return args
class SplitByLengthHandler:
def __init__(self,
model_name_or_path: str,
cache_dir: str=None,
num_proc: int=16,
length_list: list=[0, 500, 1000, 2000, 3000, 4000, 5000, 6000, 7000],
overwrite: bool=False):
self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
self.cache_dir = cache_dir
self.num_proc = num_proc
self.length_ranges_list = self._get_length_ranges_list(length_list)
self.overwrite = overwrite
pprint(self.length_ranges_list)
def _map_func(examples):
results = {}
results['idx'] = []
results['max_length'] = []
for i in range(len(examples['query'])):
idx = examples['idx'][i]
query = examples['query'][i]
pos, neg = examples['pos'][i], examples['neg'][i]
all_texts = [query] + pos + neg
max_len = 0
for x in all_texts:
tokenized_x = self.tokenizer(x)['input_ids']
if len(tokenized_x) > max_len:
max_len = len(tokenized_x)
results['idx'].append(idx)
results['max_length'].append(max_len)
return results
self._map_func = _map_func
@staticmethod
def _get_length_ranges_list(length_list: list):
length_ranges_list = []
length_list = sorted(length_list)
for i in range(len(length_list)):
length_l = length_list[i]
if i == len(length_list) - 1:
length_r = math.inf
else:
length_r = length_list[i + 1]
assert 0 <= length_l < length_r
length_ranges_list.append((length_l, length_r))
return length_ranges_list
def _process_dir(self, dir_path: str, output_dir: str):
assert os.path.isdir(dir_path)
log_info_list = []
for file in tqdm(os.listdir(dir_path), desc=f'processing {dir_path}'):
file_path = os.path.join(dir_path, file)
if not file_path.endswith('.jsonl'):
print(f"skip {file_path} ...")
continue
output_path = os.path.join(output_dir, '.'.join(file.split('.')[:-1]))
log_info = self._process_file(file_path, output_path)
log_info_list.append(log_info)
return log_info_list
def _process_file(self, file_path: str, output_path: str):
assert not os.path.isdir(file_path)
start_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
features = Features({
'query': Value('string'),
'pos': Sequence(Value('string')),
'neg': Sequence(Value('string'))
})
kd_features = Features({
'query': Value('string'),
'pos': Sequence(Value('string')),
'neg': Sequence(Value('string')),
'pos_scores': Sequence(Value('float')),
'neg_scores': Sequence(Value('float'))
})
try:
dataset = load_dataset('json', data_files=file_path, cache_dir=self.cache_dir, features=features)['train']
        except Exception:
dataset = load_dataset('json', data_files=file_path, cache_dir=self.cache_dir, features=kd_features)['train']
dataset_with_idx_list = []
for i, data in enumerate(dataset):
data['idx'] = i
dataset_with_idx_list.append(data)
dataset_with_idx = datasets.Dataset.from_list(dataset_with_idx_list)
mapped_dataset = dataset_with_idx.map(self._map_func, batched=True, num_proc=self.num_proc)
split_info_dict = {}
for length_l, length_r in self.length_ranges_list:
save_path = output_path + f'_len-{length_l}-{length_r}.jsonl'
if os.path.exists(save_path) and not self.overwrite:
print(f'{save_path} exists, skip')
continue
idxs = mapped_dataset.filter(lambda x: length_l <= x['max_length'] < length_r, num_proc=self.num_proc)
split_dataset = dataset_with_idx.select(idxs['idx'])
split_dataset = split_dataset.remove_columns('idx')
split_info_dict[f'len-{length_l}-{length_r}'] = len(split_dataset)
if len(split_dataset) > 0:
split_dataset.to_json(save_path, force_ascii=False)
end_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
size = len(dataset)
avg_length = sum(mapped_dataset['max_length']) / size
log_info = {
'file_name': os.path.basename(file_path),
'size': size,
'avg_length': avg_length,
'file_path': file_path,
'start_time': start_time,
'end_time': end_time,
'split_info': split_info_dict
}
return log_info
def run(self, input_path: str, output_dir: str, log_name: str=None):
if not os.path.exists(output_dir):
os.makedirs(output_dir)
if log_name is None:
log_path = os.path.join(output_dir, '.split_log')
else:
log_path = os.path.join(output_dir, log_name)
log_info_list = []
if os.path.isdir(input_path):
log_info_list = self._process_dir(input_path, output_dir)
else:
file_name = os.path.basename(input_path)
output_path = os.path.join(output_dir, '.'.join(file_name.split('.')[:-1]))
log_info = self._process_file(input_path, output_path)
log_info_list.append(log_info)
with open(log_path, 'a', encoding='utf-8') as f:
for log_info in log_info_list:
json.dump(log_info, f, ensure_ascii=False)
f.write('\n')
if __name__ == '__main__':
args = get_args()
input_path = args.input_path
output_dir = args.output_dir
log_name = args.log_name
handler = SplitByLengthHandler(
model_name_or_path=args.model_name_or_path,
cache_dir=args.cache_dir,
num_proc=args.num_proc,
length_list=args.length_list if isinstance(args.length_list, list) else [args.length_list],
overwrite=args.overwrite
)
handler.run(
input_path=input_path,
output_dir=output_dir,
log_name=log_name
)
print('\nDONE!')
from sentence_transformers import SentenceTransformer, models
from transformers.trainer import *
def save_ckpt_for_sentence_transformers(ckpt_dir, pooling_mode: str = 'cls', normlized: bool=True):
word_embedding_model = models.Transformer(ckpt_dir)
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), pooling_mode=pooling_mode)
if normlized:
normlize_layer = models.Normalize()
model = SentenceTransformer(modules=[word_embedding_model, pooling_model, normlize_layer], device='cpu')
else:
model = SentenceTransformer(modules=[word_embedding_model, pooling_model], device='cpu')
model.save(ckpt_dir)
class BiTrainer(Trainer):
def _save(self, output_dir: Optional[str] = None, state_dict=None):
output_dir = output_dir if output_dir is not None else self.args.output_dir
os.makedirs(output_dir, exist_ok=True)
logger.info("Saving model checkpoint to %s", output_dir)
# Save a trained model and configuration using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()`
if not hasattr(self.model, 'save'):
raise NotImplementedError(
f'MODEL {self.model.__class__.__name__} '
f'does not support save interface')
else:
self.model.save(output_dir)
if self.tokenizer is not None and self.is_world_process_zero():
self.tokenizer.save_pretrained(output_dir)
torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
# save the checkpoint for sentence-transformers library
if self.is_world_process_zero():
save_ckpt_for_sentence_transformers(output_dir,
pooling_mode=self.args.sentence_pooling_method,
normlized=self.args.normlized)
def compute_loss(self, model, inputs, return_outputs=False):
"""
How the loss is computed by Trainer. By default, all models return the loss in the first element.
Subclass and override for custom behavior.
"""
outputs = model(**inputs)
loss = outputs.loss
return (loss, outputs) if return_outputs else loss
from .flag_models import FlagModel, LLMEmbedder
from .bge_m3 import BGEM3FlagModel
from .flag_reranker import FlagReranker, FlagLLMReranker, LayerWiseFlagLLMReranker
# Embedding Model
Unlike other embedding models that use mean pooling, BGE uses the `[cls]` representation as the sentence embedding: `sentence_embeddings = model_output[0][:, 0]`.
If you use mean pooling, performance will degrade significantly, so make sure to use the correct method to obtain sentence vectors; you can refer to the usage examples we provide.
## Inference
```python
from FlagEmbedding import FlagModel
sentences_1 = ["样例数据-1", "样例数据-2"]
sentences_2 = ["样例数据-3", "样例数据-4"]
model = FlagModel('BAAI/bge-large-zh-v1.5',
query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
use_fp16=True) # Setting use_fp16 to True speeds up computation with a slight performance degradation
embeddings_1 = model.encode(sentences_1)
embeddings_2 = model.encode(sentences_2)
similarity = embeddings_1 @ embeddings_2.T
print(similarity)
# for the s2p (short query to long passage) retrieval task, use encode_queries(), which automatically adds the instruction to each query
# the corpus in a retrieval task can still use encode() or encode_corpus(), since passages don't need the instruction
queries = ['query_1', 'query_2']
passages = ["样例文档-1", "样例文档-2"]
q_embeddings = model.encode_queries(queries)
p_embeddings = model.encode(passages)
scores = q_embeddings @ p_embeddings.T
```
For the value of the argument `query_instruction_for_retrieval`, see [Model List](https://github.com/FlagOpen/FlagEmbedding/tree/master#model-list).
By default, FlagModel uses all available GPUs when encoding. You can set the environment variable `os.environ["CUDA_VISIBLE_DEVICES"]` to select specific GPUs; likewise, setting `os.environ["CUDA_VISIBLE_DEVICES"]=""` makes all GPUs unavailable.
### Using Sentence-Transformers
You can also use the `bge` models with [sentence-transformers](https://www.SBERT.net):
```
pip install -U sentence-transformers
```
```python
from sentence_transformers import SentenceTransformer
sentences_1 = ["样例数据-1", "样例数据-2"]
sentences_2 = ["样例数据-3", "样例数据-4"]
model = SentenceTransformer('BAAI/bge-large-zh-v1.5')
embeddings_1 = model.encode(sentences_1, normalize_embeddings=True)
embeddings_2 = model.encode(sentences_2, normalize_embeddings=True)
similarity = embeddings_1 @ embeddings_2.T
print(similarity)
```
For the s2p (short query to long passage) retrieval task, each short query should start with a specific instruction (for instructions, see [Model List](https://github.com/FlagOpen/FlagEmbedding/tree/master#model-list)). For long passages, the instruction is not required.
```python
from sentence_transformers import SentenceTransformer
queries = ['query_1', 'query_2']
passages = ["样例文档-1", "样例文档-2"]
instruction = "为这个句子生成表示以用于检索相关文章:"
model = SentenceTransformer('BAAI/bge-large-zh-v1.5')
q_embeddings = model.encode([instruction+q for q in queries], normalize_embeddings=True)
p_embeddings = model.encode(passages, normalize_embeddings=True)
scores = q_embeddings @ p_embeddings.T
```
### Using Langchain
You can use `bge` in langchain as follows:
```python
from langchain.embeddings import HuggingFaceBgeEmbeddings
model_name = "BAAI/bge-large-en-v1.5"
model_kwargs = {'device': 'cuda'}
encode_kwargs = {'normalize_embeddings': True} # set True to compute cosine similarity
model = HuggingFaceBgeEmbeddings(
model_name=model_name,
model_kwargs=model_kwargs,
encode_kwargs=encode_kwargs,
query_instruction="为这个句子生成表示以用于检索相关文章:"
)
model.query_instruction = "为这个句子生成表示以用于检索相关文章:"
```
### Using HuggingFace Transformers
With the `transformers` package, you can use the model as follows: first pass the input through the transformer model, then take the last hidden state of the first token (i.e., `[CLS]`) as the sentence embedding.
```python
from transformers import AutoTokenizer, AutoModel
import torch
# Sentences we want sentence embeddings for
sentences = ["样例数据-1", "样例数据-2"]
# Load model from HuggingFace Hub
tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-large-zh-v1.5')
model = AutoModel.from_pretrained('BAAI/bge-large-zh-v1.5')
model.eval()
# Tokenize sentences
encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
# for the s2p (short query to long passage) retrieval task, add an instruction to each query (no instruction is needed for passages)
# encoded_input = tokenizer([instruction + q for q in queries], padding=True, truncation=True, return_tensors='pt')
# Compute token embeddings
with torch.no_grad():
model_output = model(**encoded_input)
# Perform pooling. In this case, cls pooling.
sentence_embeddings = model_output[0][:, 0]
# normalize embeddings
sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)
print("Sentence embeddings:", sentence_embeddings)
```
## Evaluation
`baai-general-embedding` models achieve **state-of-the-art performance on both the MTEB and C-MTEB leaderboards!**
For more evaluation details and scripts, see [scripts](https://github.com/FlagOpen/FlagEmbedding/blob/master/C_MTEB/README.md).
If you want to evaluate the open-source models (or your own model) on **your own data**, see [here](../../examples/finetune).
# Embedding Model
## Frequently asked questions
**Very poor results are usually caused by incorrect usage**
Unlike other embedding models that use mean pooling, BGE uses the last hidden state of `[cls]` as the sentence embedding: `sentence_embeddings = model_output[0][:, 0]`.
If you use mean pooling, there will be a significant decrease in performance.
Therefore, make sure to use the correct method to obtain sentence vectors. You can refer to the usage method we provide.
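As a minimal illustration (mirroring the Transformers example in the Usage section below), the two pooling choices differ in a single line:
```python
import torch
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-large-zh-v1.5')
model = AutoModel.from_pretrained('BAAI/bge-large-zh-v1.5').eval()

encoded = tokenizer(["样例数据-1", "样例数据-2"], padding=True, truncation=True, return_tensors='pt')
with torch.no_grad():
    hidden = model(**encoded)[0]        # (batch, seq_len, hidden_size)

cls_embeddings = hidden[:, 0]           # correct for BGE: [CLS] pooling
mean_embeddings = hidden.mean(dim=1)    # mean pooling: noticeably worse for BGE
```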
**1. How to fine-tune the bge embedding model?**
Follow this [example](https://github.com/FlagOpen/FlagEmbedding/tree/master/examples/finetune) to prepare data and fine-tune your model; a sketch of the expected data format follows this list.
Some suggestions:
- Mine hard negatives following this [example](https://github.com/FlagOpen/FlagEmbedding/tree/master/examples/finetune#hard-negatives), which can improve the retrieval performance.
- In general, a larger `per_device_train_batch_size` brings better performance. You can increase it by enabling `--fp16`, `--deepspeed df_config.json` (for `df_config.json`, refer to [ds_config.json](https://github.com/FlagOpen/FlagEmbedding/blob/master/examples/finetune/ds_config.json)), `--gradient_checkpointing`, etc.
- If you want to maintain the performance on other tasks when fine-tuning on your data, you can use [LM-Cocktail](https://github.com/FlagOpen/FlagEmbedding/tree/master/LM_Cocktail) to merge the fine-tuned model and the original bge model. Besides, if you want to fine-tune on multiple tasks, you also can approximate the multi-task learning via model merging as [LM-Cocktail](https://github.com/FlagOpen/FlagEmbedding/tree/master/LM_Cocktail).
- If you pre-train bge on your own data, the pre-trained model cannot be used to compute similarity directly; it must be fine-tuned with contrastive learning first.
- If the accuracy of the fine-tuned model is still not high, it is recommended to use or fine-tune the cross-encoder model (bge-reranker) to re-rank the top-k results. Hard negatives are also needed to fine-tune the reranker.
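For reference, the fine-tuning files are JSON Lines, one object per line with `query`, `pos`, and `neg` fields (optionally `pos_scores`/`neg_scores` for distillation), which is the schema the data-loading code in this commit expects. A minimal sketch with hypothetical toy data:
```python
import json

# Hypothetical toy examples; real data should pair each query with mined hard negatives.
examples = [
    {
        "query": "what is a hard negative?",
        "pos": ["A hard negative is an irrelevant passage that superficially looks relevant."],
        "neg": ["The weather today is sunny.", "Cats are popular pets."],
    },
]
with open("toy_train_data.jsonl", "w", encoding="utf-8") as f:
    for example in examples:
        f.write(json.dumps(example, ensure_ascii=False) + "\n")
```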
Here is the way we used to fine-tune `bge-large-zh-v1.5`:
The fine-tuning datasets consist of t2ranking, dureader, mmarco, cmedqav2, multi-cpr, nli-zh, ocnli, and cmnli.
For t2ranking, dureader, and mmarco, we mine hard negatives;
for nli-zh, ocnli, and cmnli, we use the pairs whose label equals 0 as negatives;
for cmedqav2 and multi-cpr, we randomly sample negatives.
The fine-tuning settings are: train_group_size=2, learning_rate=1e-5, max_epoch=5.
We train two models: one is fine-tuned with `--query_instruction_for_retrieval "为这个句子生成表示以用于检索相关文章:"`,
and the other is fine-tuned with `--query_instruction_for_retrieval ""`;
we then merge the two variants into one model so that the final model can be used both with and without an instruction.
<details>
<summary>2. The similarity score between two dissimilar sentences is higher than 0.5</summary>
<!-- ### The similarity score between two dissimilar sentences is higher than 0.5 -->
**We suggest using bge v1.5, which alleviates the issue of the similarity distribution.**
Since we fine-tune the models by contrastive learning with a temperature of 0.01,
the similarity distribution of the current BGE model is roughly in the interval \[0.6, 1\].
So a similarity score greater than 0.5 does not indicate that the two sentences are similar.
For downstream tasks, such as passage retrieval or semantic similarity,
**what matters is the relative order of the scores, not the absolute value.**
If you need to filter similar sentences based on a similarity threshold,
please select an appropriate similarity threshold based on the similarity distribution on your data (such as 0.8, 0.85, or even 0.9).
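As a minimal sketch (using sentence-transformers as in the Usage section below; the 0.85 threshold is only a placeholder to tune on your own score distribution):
```python
import numpy as np
from sentence_transformers import SentenceTransformer

model = SentenceTransformer('BAAI/bge-large-zh-v1.5')
pairs = [("样例数据-1", "样例数据-2"), ("样例数据-1", "样例数据-3")]

emb_a = model.encode([a for a, _ in pairs], normalize_embeddings=True)
emb_b = model.encode([b for _, b in pairs], normalize_embeddings=True)
scores = np.sum(emb_a * emb_b, axis=1)  # cosine similarity per pair

threshold = 0.85
similar_pairs = [p for p, s in zip(pairs, scores) if s >= threshold]
```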
</details>
<details>
<summary>3. When does the query instruction need to be used</summary>
<!-- ### When does the query instruction need to be used -->
For `bge-*-v1.5`, we improved its retrieval ability when no instruction is used.
Using no instruction causes only a slight degradation in retrieval performance compared with using one.
So for convenience, you can generate embeddings without an instruction in all cases.
For a retrieval task that uses short queries to find long related documents,
it is recommended to add instructions for these short queries.
**The best method to decide whether to add instructions for queries is choosing the setting that achieves better performance on your task.**
In all cases, the documents/passages do not need an instruction.
</details>
## Usage
### Using FlagEmbedding
Install:
```
git clone https://github.com/FlagOpen/FlagEmbedding.git
cd FlagEmbedding
pip install -e .
```
or:
```
pip install -U FlagEmbedding
```
```python
from FlagEmbedding import FlagModel
sentences_1 = ["样例数据-1", "样例数据-2"]
sentences_2 = ["样例数据-3", "样例数据-4"]
model = FlagModel('BAAI/bge-large-zh-v1.5',
query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",
use_fp16=True) # Setting use_fp16 to True speeds up computation with a slight performance degradation
embeddings_1 = model.encode(sentences_1)
embeddings_2 = model.encode(sentences_2)
similarity = embeddings_1 @ embeddings_2.T
print(similarity)
# for the s2p (short query to long passage) retrieval task, use encode_queries(), which automatically adds the instruction to each query
# the corpus in a retrieval task can still use encode() or encode_corpus(), since passages don't need the instruction
queries = ['query_1', 'query_2']
passages = ["样例文档-1", "样例文档-2"]
q_embeddings = model.encode_queries(queries)
p_embeddings = model.encode(passages)
scores = q_embeddings @ p_embeddings.T
```
For the value of the argument `query_instruction_for_retrieval`, see [Model List](https://github.com/FlagOpen/FlagEmbedding/tree/master#model-list).
By default, FlagModel will use all available GPUs when encoding. Please set `os.environ["CUDA_VISIBLE_DEVICES"]` to select specific GPUs.
You can also set `os.environ["CUDA_VISIBLE_DEVICES"]=""` to make all GPUs unavailable.
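A minimal sketch of pinning encoding to a single GPU (the variable must be set before CUDA is initialized, hence before the import):
```python
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # use GPU 0 only; "" forces CPU

from FlagEmbedding import FlagModel

model = FlagModel('BAAI/bge-large-zh-v1.5', use_fp16=True)
embeddings = model.encode(["样例数据-1"])
```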
### Using Sentence-Transformers
You can also use the `bge` models with [sentence-transformers](https://www.SBERT.net):
```
pip install -U sentence-transformers
```
```python
from sentence_transformers import SentenceTransformer
sentences_1 = ["样例数据-1", "样例数据-2"]
sentences_2 = ["样例数据-3", "样例数据-4"]
model = SentenceTransformer('BAAI/bge-large-zh-v1.5')
embeddings_1 = model.encode(sentences_1, normalize_embeddings=True)
embeddings_2 = model.encode(sentences_2, normalize_embeddings=True)
similarity = embeddings_1 @ embeddings_2.T
print(similarity)
```
For the s2p (short query to long passage) retrieval task,
each short query should start with an instruction (for instructions, see [Model List](https://github.com/FlagOpen/FlagEmbedding/tree/master#model-list)).
The instruction is not needed for passages.
```python
from sentence_transformers import SentenceTransformer
queries = ['query_1', 'query_2']
passages = ["样例文档-1", "样例文档-2"]
instruction = "为这个句子生成表示以用于检索相关文章:"
model = SentenceTransformer('BAAI/bge-large-zh-v1.5')
q_embeddings = model.encode([instruction+q for q in queries], normalize_embeddings=True)
p_embeddings = model.encode(passages, normalize_embeddings=True)
scores = q_embeddings @ p_embeddings.T
```
### Using Langchain
You can use `bge` in langchain like this:
```python
from langchain.embeddings import HuggingFaceBgeEmbeddings
model_name = "BAAI/bge-large-en-v1.5"
model_kwargs = {'device': 'cuda'}
encode_kwargs = {'normalize_embeddings': True} # set True to compute cosine similarity
model = HuggingFaceBgeEmbeddings(
model_name=model_name,
model_kwargs=model_kwargs,
encode_kwargs=encode_kwargs,
query_instruction="为这个句子生成表示以用于检索相关文章:"
)
model.query_instruction = "为这个句子生成表示以用于检索相关文章:"
```
### Using HuggingFace Transformers
With the transformers package, you can use the model like this: first pass your input through the transformer model, then select the last hidden state of the first token (i.e., `[CLS]`) as the sentence embedding.
```python
from transformers import AutoTokenizer, AutoModel
import torch
# Sentences we want sentence embeddings for
sentences = ["样例数据-1", "样例数据-2"]
# Load model from HuggingFace Hub
tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-large-zh-v1.5')
model = AutoModel.from_pretrained('BAAI/bge-large-zh-v1.5')
model.eval()
# Tokenize sentences
encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
# for the s2p (short query to long passage) retrieval task, add an instruction to each query (no instruction is needed for passages)
# encoded_input = tokenizer([instruction + q for q in queries], padding=True, truncation=True, return_tensors='pt')
# Compute token embeddings
with torch.no_grad():
model_output = model(**encoded_input)
# Perform pooling. In this case, cls pooling.
sentence_embeddings = model_output[0][:, 0]
# normalize embeddings
sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)
print("Sentence embeddings:", sentence_embeddings)
```
## Evaluation
`baai-general-embedding` models achieve **state-of-the-art performance on both the MTEB and C-MTEB leaderboards!**
For more details and evaluation tools, see our [scripts](https://github.com/FlagOpen/FlagEmbedding/blob/master/C_MTEB/README.md).
If you want to evaluate the model (or your own model) on **your data**, you can refer to this [tool](https://github.com/FlagOpen/FlagEmbedding/tree/master/examples/finetune#6-evaluate-model).
- **MTEB**:
| Model Name | Dimension | Sequence Length | Average (56) | Retrieval (15) |Clustering (11) | Pair Classification (3) | Reranking (4) | STS (10) | Summarization (1) | Classification (12) |
|:----:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
| [BAAI/bge-large-en-v1.5](https://huggingface.co/BAAI/bge-large-en-v1.5) | 1024 | 512 | **64.23** | **54.29** | 46.08 | 87.12 | 60.03 | 83.11 | 31.61 | 75.97 |
| [BAAI/bge-base-en-v1.5](https://huggingface.co/BAAI/bge-base-en-v1.5) | 768 | 512 | 63.55 | 53.25 | 45.77 | 86.55 | 58.86 | 82.4 | 31.07 | 75.53 |
| [BAAI/bge-small-en-v1.5](https://huggingface.co/BAAI/bge-small-en-v1.5) | 384 | 512 | 62.17 |51.68 | 43.82 | 84.92 | 58.36 | 81.59 | 30.12 | 74.14 |
| [bge-large-en](https://huggingface.co/BAAI/bge-large-en) | 1024 | 512 | 63.98 | 53.9 | 46.98 | 85.8 | 59.48 | 81.56 | 32.06 | 76.21 |
| [bge-base-en](https://huggingface.co/BAAI/bge-base-en) | 768 | 512 | 63.36 | 53.0 | 46.32 | 85.86 | 58.7 | 81.84 | 29.27 | 75.27 |
| [gte-large](https://huggingface.co/thenlper/gte-large) | 1024 | 512 | 63.13 | 52.22 | 46.84 | 85.00 | 59.13 | 83.35 | 31.66 | 73.33 |
| [gte-base](https://huggingface.co/thenlper/gte-base) | 768 | 512 | 62.39 | 51.14 | 46.2 | 84.57 | 58.61 | 82.3 | 31.17 | 73.01 |
| [e5-large-v2](https://huggingface.co/intfloat/e5-large-v2) | 1024| 512 | 62.25 | 50.56 | 44.49 | 86.03 | 56.61 | 82.05 | 30.19 | 75.24 |
| [bge-small-en](https://huggingface.co/BAAI/bge-small-en) | 384 | 512 | 62.11 | 51.82 | 44.31 | 83.78 | 57.97 | 80.72 | 30.53 | 74.37 |
| [instructor-xl](https://huggingface.co/hkunlp/instructor-xl) | 768 | 512 | 61.79 | 49.26 | 44.74 | 86.62 | 57.29 | 83.06 | 32.32 | 61.79 |
| [e5-base-v2](https://huggingface.co/intfloat/e5-base-v2) | 768 | 512 | 61.5 | 50.29 | 43.80 | 85.73 | 55.91 | 81.05 | 30.28 | 73.84 |
| [gte-small](https://huggingface.co/thenlper/gte-small) | 384 | 512 | 61.36 | 49.46 | 44.89 | 83.54 | 57.7 | 82.07 | 30.42 | 72.31 |
| [text-embedding-ada-002](https://platform.openai.com/docs/guides/embeddings) | 1536 | 8192 | 60.99 | 49.25 | 45.9 | 84.89 | 56.32 | 80.97 | 30.8 | 70.93 |
| [e5-small-v2](https://huggingface.co/intfloat/e5-base-v2) | 384 | 512 | 59.93 | 49.04 | 39.92 | 84.67 | 54.32 | 80.39 | 31.16 | 72.94 |
| [sentence-t5-xxl](https://huggingface.co/sentence-transformers/sentence-t5-xxl) | 768 | 512 | 59.51 | 42.24 | 43.72 | 85.06 | 56.42 | 82.63 | 30.08 | 73.42 |
| [all-mpnet-base-v2](https://huggingface.co/sentence-transformers/all-mpnet-base-v2) | 768 | 514 | 57.78 | 43.81 | 43.69 | 83.04 | 59.36 | 80.28 | 27.49 | 65.07 |
| [sgpt-bloom-7b1-msmarco](https://huggingface.co/bigscience/sgpt-bloom-7b1-msmarco) | 4096 | 2048 | 57.59 | 48.22 | 38.93 | 81.9 | 55.65 | 77.74 | 33.6 | 66.19 |
- **C-MTEB**:
We created the C-MTEB benchmark for Chinese text embeddings; it consists of 31 datasets across 6 tasks.
Please refer to [C_MTEB](https://github.com/FlagOpen/FlagEmbedding/blob/master/C_MTEB/README.md) for a detailed introduction.
| Model | Embedding dimension | Avg | Retrieval | STS | PairClassification | Classification | Reranking | Clustering |
|:-------------------------------|:--------:|:--------:|:--------:|:--------:|:--------:|:--------:|:--------:|:--------:|
| [**BAAI/bge-large-zh-v1.5**](https://huggingface.co/BAAI/bge-large-zh-v1.5) | 1024 | **64.53** | 70.46 | 56.25 | 81.6 | 69.13 | 65.84 | 48.99 |
| [BAAI/bge-base-zh-v1.5](https://huggingface.co/BAAI/bge-base-zh-v1.5) | 768 | 63.13 | 69.49 | 53.72 | 79.75 | 68.07 | 65.39 | 47.53 |
| [BAAI/bge-small-zh-v1.5](https://huggingface.co/BAAI/bge-small-zh-v1.5) | 512 | 57.82 | 61.77 | 49.11 | 70.41 | 63.96 | 60.92 | 44.18 |
| [BAAI/bge-large-zh](https://huggingface.co/BAAI/bge-large-zh) | 1024 | 64.20 | 71.53 | 54.98 | 78.94 | 68.32 | 65.11 | 48.39 |
| [bge-large-zh-noinstruct](https://huggingface.co/BAAI/bge-large-zh-noinstruct) | 1024 | 63.53 | 70.55 | 53 | 76.77 | 68.58 | 64.91 | 50.01 |
| [BAAI/bge-base-zh](https://huggingface.co/BAAI/bge-base-zh) | 768 | 62.96 | 69.53 | 54.12 | 77.5 | 67.07 | 64.91 | 47.63 |
| [multilingual-e5-large](https://huggingface.co/intfloat/multilingual-e5-large) | 1024 | 58.79 | 63.66 | 48.44 | 69.89 | 67.34 | 56.00 | 48.23 |
| [BAAI/bge-small-zh](https://huggingface.co/BAAI/bge-small-zh) | 512 | 58.27 | 63.07 | 49.45 | 70.35 | 63.64 | 61.48 | 45.09 |
| [m3e-base](https://huggingface.co/moka-ai/m3e-base) | 768 | 57.10 | 56.91 | 50.47 | 63.99 | 67.52 | 59.34 | 47.68 |
| [m3e-large](https://huggingface.co/moka-ai/m3e-large) | 1024 | 57.05 | 54.75 | 50.42 | 64.3 | 68.2 | 59.66 | 48.88 |
| [multilingual-e5-base](https://huggingface.co/intfloat/multilingual-e5-base) | 768 | 55.48 | 61.63 | 46.49 | 67.07 | 65.35 | 54.35 | 40.68 |
| [multilingual-e5-small](https://huggingface.co/intfloat/multilingual-e5-small) | 384 | 55.38 | 59.95 | 45.27 | 66.45 | 65.85 | 53.86 | 45.26 |
| [text-embedding-ada-002(OpenAI)](https://platform.openai.com/docs/guides/embeddings/what-are-embeddings) | 1536 | 53.02 | 52.0 | 43.35 | 69.56 | 64.31 | 54.28 | 45.68 |
| [luotuo](https://huggingface.co/silk-road/luotuo-bert-medium) | 1024 | 49.37 | 44.4 | 42.78 | 66.62 | 61 | 49.25 | 44.39 |
| [text2vec-base](https://huggingface.co/shibing624/text2vec-base-chinese) | 768 | 47.63 | 38.79 | 43.41 | 67.41 | 62.19 | 49.45 | 37.66 |
| [text2vec-large](https://huggingface.co/GanymedeNil/text2vec-large-chinese) | 1024 | 47.36 | 41.94 | 44.97 | 70.86 | 60.66 | 49.16 | 30.02 |
## Acknowledgement
Part of the code is developed based on [Dense](https://github.com/luyug/Dense).
## Citation
If you find this repository useful, please consider giving it a star :star: and a citation
```
@misc{bge_embedding,
title={C-Pack: Packaged Resources To Advance General Chinese Embedding},
author={Shitao Xiao and Zheng Liu and Peitian Zhang and Niklas Muennighoff},
year={2023},
eprint={2309.07597},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
```
from .modeling import BiEncoderModel, EncoderOutput
from .trainer import BiTrainer
import os
from dataclasses import dataclass, field
from typing import Optional
from transformers import TrainingArguments
@dataclass
class ModelArguments:
"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
"""
model_name_or_path: str = field(
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
)
config_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
)
tokenizer_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
)
cache_dir: Optional[str] = field(
default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
)
@dataclass
class DataArguments:
train_data: str = field(
default=None, metadata={"help": "Path to train data"}
)
train_group_size: int = field(default=8)
    query_max_len: int = field(
        default=32,
        metadata={
            "help": "The maximum total input sequence length after tokenization for the query. Sequences longer "
                    "than this will be truncated, sequences shorter will be padded."
        },
    )
passage_max_len: int = field(
default=128,
metadata={
"help": "The maximum total input sequence length after tokenization for passage. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
},
)
max_example_num_per_dataset: int = field(
default=100000000, metadata={"help": "the max number of examples for each dataset"}
)
    query_instruction_for_retrieval: str = field(
        default=None, metadata={"help": "instruction for query"}
    )
passage_instruction_for_retrieval: str = field(
default=None, metadata={"help": "instruction for passage"}
)
def __post_init__(self):
if not os.path.exists(self.train_data):
            raise FileNotFoundError(f"cannot find file: {self.train_data}, please set a correct path")
@dataclass
class RetrieverTrainingArguments(TrainingArguments):
negatives_cross_device: bool = field(default=False, metadata={"help": "share negatives across devices"})
temperature: Optional[float] = field(default=0.02)
fix_position_embedding: bool = field(default=False, metadata={"help": "Freeze the parameters of position embeddings"})
sentence_pooling_method: str = field(default='cls', metadata={"help": "the pooling method, should be cls or mean"})
normlized: bool = field(default=True)
use_inbatch_neg: bool = field(default=True, metadata={"help": "use passages in the same batch as negatives"})
import math
import os.path
import random
from dataclasses import dataclass
from typing import List, Tuple
import datasets
from torch.utils.data import Dataset
from transformers import DataCollatorWithPadding, PreTrainedTokenizer
from .arguments import DataArguments
class TrainDatasetForEmbedding(Dataset):
def __init__(
self,
args: DataArguments,
tokenizer: PreTrainedTokenizer
):
if os.path.isdir(args.train_data):
train_datasets = []
for file in os.listdir(args.train_data):
temp_dataset = datasets.load_dataset('json', data_files=os.path.join(args.train_data, file),
split='train')
if len(temp_dataset) > args.max_example_num_per_dataset:
temp_dataset = temp_dataset.select(
random.sample(list(range(len(temp_dataset))), args.max_example_num_per_dataset))
train_datasets.append(temp_dataset)
self.dataset = datasets.concatenate_datasets(train_datasets)
else:
self.dataset = datasets.load_dataset('json', data_files=args.train_data, split='train')
self.tokenizer = tokenizer
self.args = args
self.total_len = len(self.dataset)
def __len__(self):
return self.total_len
def __getitem__(self, item) -> Tuple[str, List[str]]:
query = self.dataset[item]['query']
if self.args.query_instruction_for_retrieval is not None:
query = self.args.query_instruction_for_retrieval + query
passages = []
assert isinstance(self.dataset[item]['pos'], list)
pos = random.choice(self.dataset[item]['pos'])
passages.append(pos)
if len(self.dataset[item]['neg']) < self.args.train_group_size - 1:
num = math.ceil((self.args.train_group_size - 1) / len(self.dataset[item]['neg']))
negs = random.sample(self.dataset[item]['neg'] * num, self.args.train_group_size - 1)
else:
negs = random.sample(self.dataset[item]['neg'], self.args.train_group_size - 1)
passages.extend(negs)
if self.args.passage_instruction_for_retrieval is not None:
passages = [self.args.passage_instruction_for_retrieval+p for p in passages]
return query, passages
@dataclass
class EmbedCollator(DataCollatorWithPadding):
"""
Wrapper that does conversion from List[Tuple[encode_qry, encode_psg]] to List[qry], List[psg]
and pass batch separately to the actual collator.
Abstract out data detail for the model.
"""
query_max_len: int = 32
passage_max_len: int = 128
def padding_score(self, teacher_score):
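        # Samples missing teacher scores get a fabricated confident distribution:
        # all weight on the first (positive) passage of the group.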
group_size = None
for scores in teacher_score:
if scores is not None:
group_size = len(scores)
break
if group_size is None:
return None
padding_scores = [100.0] + [0.0] * (group_size - 1)
new_teacher_score = []
for scores in teacher_score:
if scores is None:
new_teacher_score.append(padding_scores)
else:
new_teacher_score.append(scores)
return new_teacher_score
def __call__(self, features):
query = [f[0] for f in features]
passage = [f[1] for f in features]
if isinstance(query[0], list):
query = sum(query, [])
if isinstance(passage[0], list):
passage = sum(passage, [])
q_collated = self.tokenizer(
query,
padding=True,
truncation=True,
max_length=self.query_max_len,
return_tensors="pt",
)
d_collated = self.tokenizer(
passage,
padding=True,
truncation=True,
max_length=self.passage_max_len,
return_tensors="pt",
)
return {"query": q_collated, "passage": d_collated}
import faiss
import torch
import logging
import datasets
import numpy as np
from tqdm import tqdm
from typing import Optional
from dataclasses import dataclass, field
from transformers import HfArgumentParser
from FlagEmbedding import FlagModel
logger = logging.getLogger(__name__)
@dataclass
class Args:
encoder: str = field(
default="BAAI/bge-base-en-v1.5",
metadata={'help': 'The encoder name or path.'}
)
fp16: bool = field(
default=False,
metadata={'help': 'Use fp16 in inference?'}
)
add_instruction: bool = field(
default=False,
metadata={'help': 'Add query-side instruction?'}
)
    corpus_data: str = field(
        default="namespace-Pt/msmarco-corpus",
        metadata={'help': 'candidate passages'}
    )
    query_data: str = field(
        default="namespace-Pt/msmarco",
        metadata={'help': 'queries and their positive passages for evaluation'}
    )
max_query_length: int = field(
default=32,
metadata={'help': 'Max query length.'}
)
max_passage_length: int = field(
default=128,
metadata={'help': 'Max passage length.'}
)
batch_size: int = field(
default=256,
metadata={'help': 'Inference batch size.'}
)
index_factory: str = field(
default="Flat",
metadata={'help': 'Faiss index factory.'}
)
k: int = field(
default=100,
metadata={'help': 'How many neighbors to retrieve?'}
)
save_embedding: bool = field(
default=False,
metadata={'help': 'Save embeddings in memmap at save_dir?'}
)
load_embedding: bool = field(
default=False,
metadata={'help': 'Load embeddings from save_dir?'}
)
save_path: str = field(
default="embeddings.memmap",
metadata={'help': 'Path to save embeddings.'}
)
def index(model: FlagModel, corpus: datasets.Dataset, batch_size: int = 256, max_length: int=512, index_factory: str = "Flat", save_path: str = None, save_embedding: bool = False, load_embedding: bool = False):
"""
1. Encode the entire corpus into dense embeddings;
2. Create faiss index;
3. Optionally save embeddings.
"""
if load_embedding:
test = model.encode("test")
dtype = test.dtype
dim = len(test)
corpus_embeddings = np.memmap(
save_path,
mode="r",
dtype=dtype
).reshape(-1, dim)
else:
corpus_embeddings = model.encode_corpus(corpus["content"], batch_size=batch_size, max_length=max_length)
dim = corpus_embeddings.shape[-1]
if save_embedding:
logger.info(f"saving embeddings at {save_path}...")
memmap = np.memmap(
save_path,
shape=corpus_embeddings.shape,
mode="w+",
dtype=corpus_embeddings.dtype
)
length = corpus_embeddings.shape[0]
# add in batch
save_batch_size = 10000
if length > save_batch_size:
for i in tqdm(range(0, length, save_batch_size), leave=False, desc="Saving Embeddings"):
j = min(i + save_batch_size, length)
memmap[i: j] = corpus_embeddings[i: j]
else:
memmap[:] = corpus_embeddings
# create faiss index
faiss_index = faiss.index_factory(dim, index_factory, faiss.METRIC_INNER_PRODUCT)
if model.device == torch.device("cuda"):
# co = faiss.GpuClonerOptions()
co = faiss.GpuMultipleClonerOptions()
co.useFloat16 = True
# faiss_index = faiss.index_cpu_to_gpu(faiss.StandardGpuResources(), 0, faiss_index, co)
faiss_index = faiss.index_cpu_to_all_gpus(faiss_index, co)
# NOTE: faiss only accepts float32
logger.info("Adding embeddings...")
corpus_embeddings = corpus_embeddings.astype(np.float32)
faiss_index.train(corpus_embeddings)
faiss_index.add(corpus_embeddings)
return faiss_index
def search(model: FlagModel, queries: datasets.Dataset, faiss_index: faiss.Index, k: int = 100, batch_size: int = 256, max_length: int = 512):
"""
1. Encode queries into dense embeddings;
2. Search through faiss index
"""
query_embeddings = model.encode_queries(queries["query"], batch_size=batch_size, max_length=max_length)
query_size = len(query_embeddings)
all_scores = []
all_indices = []
for i in tqdm(range(0, query_size, batch_size), desc="Searching"):
j = min(i + batch_size, query_size)
query_embedding = query_embeddings[i: j]
score, indice = faiss_index.search(query_embedding.astype(np.float32), k=k)
all_scores.append(score)
all_indices.append(indice)
all_scores = np.concatenate(all_scores, axis=0)
all_indices = np.concatenate(all_indices, axis=0)
return all_scores, all_indices
def evaluate(preds,
preds_scores,
labels,
cutoffs=[1, 10, 100]):
"""
Evaluate MRR and Recall at cutoffs.
"""
metrics = {}
# MRR
mrrs = np.zeros(len(cutoffs))
for pred, label in zip(preds, labels):
jump = False
for i, x in enumerate(pred, 1):
if x in label:
for k, cutoff in enumerate(cutoffs):
if i <= cutoff:
mrrs[k] += 1 / i
jump = True
if jump:
break
mrrs /= len(preds)
for i, cutoff in enumerate(cutoffs):
mrr = mrrs[i]
metrics[f"MRR@{cutoff}"] = mrr
    # Recall: fraction of ground-truth positives retrieved within each cutoff
    recalls = np.zeros(len(cutoffs))
    for pred, label in zip(preds, labels):
        for k, cutoff in enumerate(cutoffs):
            hits = np.intersect1d(label, pred[:cutoff])
            recalls[k] += len(hits) / len(label)
recalls /= len(preds)
for i, cutoff in enumerate(cutoffs):
recall = recalls[i]
metrics[f"Recall@{cutoff}"] = recall
# AUC
pred_hard_encodings = []
for pred, label in zip(preds, labels):
pred_hard_encoding = np.isin(pred, label).astype(int).tolist()
pred_hard_encodings.append(pred_hard_encoding)
    from sklearn.metrics import roc_auc_score, ndcg_score
pred_hard_encodings1d = np.asarray(pred_hard_encodings).flatten()
preds_scores1d = preds_scores.flatten()
auc = roc_auc_score(pred_hard_encodings1d, preds_scores1d)
metrics['AUC@100'] = auc
# nDCG
for k, cutoff in enumerate(cutoffs):
nDCG = ndcg_score(pred_hard_encodings, preds_scores, k=cutoff)
metrics[f"nDCG@{cutoff}"] = nDCG
return metrics
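# A tiny worked example (illustrative only) of the metric definitions above:
# one query whose single positive appears at rank 2 gives MRR@10 = 1/2 and
# Recall@10 = 1/1; the scores are fabricated and only need to be descending.
def _demo_evaluate():
    preds = [["p3", "p1", "p7"]]           # ranked passages for one query
    scores = np.array([[0.9, 0.8, 0.7]])   # fabricated similarity scores
    labels = [["p1"]]                      # ground-truth positives
    print(evaluate(preds, scores, labels, cutoffs=[1, 10]))
    # expected: MRR@1 = 0.0, MRR@10 = 0.5, Recall@1 = 0.0, Recall@10 = 1.0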
def main():
parser = HfArgumentParser([Args])
args: Args = parser.parse_args_into_dataclasses()[0]
    if args.query_data == 'namespace-Pt/msmarco':
        assert args.corpus_data == 'namespace-Pt/msmarco-corpus'
eval_data = datasets.load_dataset("namespace-Pt/msmarco", split="dev")
corpus = datasets.load_dataset("namespace-Pt/msmarco-corpus", split="train")
else:
eval_data = datasets.load_dataset('json', data_files=args.query_data, split='train')
corpus = datasets.load_dataset('json', data_files=args.corpus_data, split='train')
model = FlagModel(
args.encoder,
query_instruction_for_retrieval="Represent this sentence for searching relevant passages: " if args.add_instruction else None,
use_fp16=args.fp16
)
faiss_index = index(
model=model,
corpus=corpus,
batch_size=args.batch_size,
max_length=args.max_passage_length,
index_factory=args.index_factory,
save_path=args.save_path,
save_embedding=args.save_embedding,
load_embedding=args.load_embedding
)
scores, indices = search(
model=model,
queries=eval_data,
faiss_index=faiss_index,
k=args.k,
batch_size=args.batch_size,
max_length=args.max_query_length
)
retrieval_results = []
for indice in indices:
        # faiss pads with -1 when fewer than k results are returned
        indice = indice[indice != -1].tolist()
retrieval_results.append(corpus[indice]["content"])
ground_truths = []
for sample in eval_data:
ground_truths.append(sample["positive"])
metrics = evaluate(retrieval_results, scores, ground_truths)
print(metrics)
if __name__ == "__main__":
main()
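# Example invocation of the evaluation script above (the filename
# eval_msmarco.py and the argument values are illustrative; the flag names
# match the Args dataclass fields referenced in main()):
#
#   python eval_msmarco.py \
#       --encoder BAAI/bge-base-en \
#       --fp16 \
#       --k 100 \
#       --batch_size 256 \
#       --index_factory Flat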
import argparse
import json
import random
import numpy as np
import faiss
from tqdm import tqdm
from FlagEmbedding import FlagModel
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--model_name_or_path', default="BAAI/bge-base-en", type=str)
parser.add_argument('--input_file', default=None, type=str)
parser.add_argument('--candidate_pool', default=None, type=str)
parser.add_argument('--output_file', default=None, type=str)
parser.add_argument('--range_for_sampling', default="10-210", type=str, help="range to sample negatives")
parser.add_argument('--use_gpu_for_searching', action='store_true', help='use faiss-gpu')
parser.add_argument('--negative_number', default=15, type=int, help='the number of negatives')
parser.add_argument('--query_instruction_for_retrieval', default="")
return parser.parse_args()
def create_index(embeddings, use_gpu):
index = faiss.IndexFlatIP(len(embeddings[0]))
embeddings = np.asarray(embeddings, dtype=np.float32)
if use_gpu:
co = faiss.GpuMultipleClonerOptions()
co.shard = True
co.useFloat16 = True
index = faiss.index_cpu_to_all_gpus(index, co=co)
index.add(embeddings)
return index
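# Note on the clone options above: co.shard = True splits the index across all
# visible GPUs (each holds a slice of the vectors) rather than replicating it,
# and useFloat16 stores vectors in half precision to halve GPU memory.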
def batch_search(index,
query,
topk: int = 200,
batch_size: int = 64):
all_scores, all_inxs = [], []
for start_index in tqdm(range(0, len(query), batch_size), desc="Batches", disable=len(query) < 256):
batch_query = query[start_index:start_index + batch_size]
batch_scores, batch_inxs = index.search(np.asarray(batch_query, dtype=np.float32), k=topk)
all_scores.extend(batch_scores.tolist())
all_inxs.extend(batch_inxs.tolist())
return all_scores, all_inxs
def get_corpus(candidate_pool):
    corpus = []
    with open(candidate_pool) as f:
        for line in f:
            line = json.loads(line.strip())
            corpus.append(line['text'])
    return corpus
def find_knn_neg(model, input_file, candidate_pool, output_file, sample_range, negative_number, use_gpu):
corpus = []
queries = []
train_data = []
    with open(input_file) as f:
        for line in f:
            line = json.loads(line.strip())
            train_data.append(line)
            corpus.extend(line['pos'])
            if 'neg' in line:
                corpus.extend(line['neg'])
            queries.append(line['query'])
if candidate_pool is not None:
if not isinstance(candidate_pool, list):
candidate_pool = get_corpus(candidate_pool)
corpus = list(set(candidate_pool))
else:
corpus = list(set(corpus))
    print(f'encoding corpus (size={len(corpus)})--------------')
    p_vecs = model.encode(corpus, batch_size=256)
    print(f'encoding queries (size={len(queries)})--------------')
    q_vecs = model.encode_queries(queries, batch_size=256)
    print('creating index and searching------------------')
index = create_index(p_vecs, use_gpu=use_gpu)
_, all_inxs = batch_search(index, q_vecs, topk=sample_range[-1])
assert len(all_inxs) == len(train_data)
for i, data in enumerate(train_data):
query = data['query']
inxs = all_inxs[i][sample_range[0]:sample_range[1]]
filtered_inx = []
        for inx in inxs:
            if inx == -1:
                break  # faiss pads with -1; all later entries are invalid too
            if corpus[inx] not in data['pos'] and corpus[inx] != query:
                filtered_inx.append(inx)
if len(filtered_inx) > negative_number:
filtered_inx = random.sample(filtered_inx, negative_number)
data['neg'] = [corpus[inx] for inx in filtered_inx]
    with open(output_file, 'w') as f:
        for data in train_data:
            if len(data['neg']) < negative_number:
                # top up with random corpus samples, excluding the positives
                samples = random.sample(corpus, negative_number - len(data['neg']) + len(data['pos']))
                samples = [sent for sent in samples if sent not in data['pos']]
                data['neg'].extend(samples[: negative_number - len(data['neg'])])
            f.write(json.dumps(data, ensure_ascii=False) + '\n')
if __name__ == '__main__':
args = get_args()
sample_range = args.range_for_sampling.split('-')
sample_range = [int(x) for x in sample_range]
model = FlagModel(args.model_name_or_path, query_instruction_for_retrieval=args.query_instruction_for_retrieval)
find_knn_neg(model,
input_file=args.input_file,
candidate_pool=args.candidate_pool,
output_file=args.output_file,
sample_range=sample_range,
negative_number=args.negative_number,
use_gpu=args.use_gpu_for_searching)
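# The input_file consumed above is JSONL with one example per line; the field
# names follow the keys accessed in find_knn_neg() ('query', 'pos', and an
# optional 'neg'). An illustrative line (content is made up):
#
#   {"query": "what is faiss?", "pos": ["Faiss is a library for similarity search."]}
#
# The output_file has the same schema, with a mined 'neg' list of up to
# negative_number hard negatives appended to each example.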
import logging
from dataclasses import dataclass
from typing import Dict, Optional
import torch
import torch.distributed as dist
from torch import nn, Tensor
from transformers import AutoModel
from transformers.file_utils import ModelOutput
logger = logging.getLogger(__name__)
@dataclass
class EncoderOutput(ModelOutput):
q_reps: Optional[Tensor] = None
p_reps: Optional[Tensor] = None
loss: Optional[Tensor] = None
scores: Optional[Tensor] = None
class BiEncoderModel(nn.Module):
TRANSFORMER_CLS = AutoModel
def __init__(self,
model_name: str = None,
normlized: bool = False,
sentence_pooling_method: str = 'cls',
negatives_cross_device: bool = False,
temperature: float = 1.0,
use_inbatch_neg: bool = True
):
super().__init__()
self.model = AutoModel.from_pretrained(model_name)
self.cross_entropy = nn.CrossEntropyLoss(reduction='mean')
self.normlized = normlized
self.sentence_pooling_method = sentence_pooling_method
self.temperature = temperature
self.use_inbatch_neg = use_inbatch_neg
self.config = self.model.config
if not normlized:
self.temperature = 1.0
logger.info("reset temperature = 1.0 due to using inner product to compute similarity")
        if normlized:
            if self.temperature > 0.5:
                raise ValueError("Temperature should be smaller than 1.0 when using cosine similarity (i.e., normlized=True). A value in the range 0.01-0.1 is recommended.")
self.negatives_cross_device = negatives_cross_device
        if self.negatives_cross_device:
            if not dist.is_initialized():
                raise ValueError('Distributed training has not been initialized for representation all gather.')
            self.process_rank = dist.get_rank()
            self.world_size = dist.get_world_size()
def gradient_checkpointing_enable(self, **kwargs):
self.model.gradient_checkpointing_enable(**kwargs)
def sentence_embedding(self, hidden_state, mask):
if self.sentence_pooling_method == 'mean':
s = torch.sum(hidden_state * mask.unsqueeze(-1).float(), dim=1)
d = mask.sum(axis=1, keepdim=True).float()
return s / d
elif self.sentence_pooling_method == 'cls':
return hidden_state[:, 0]
def encode(self, features):
if features is None:
return None
psg_out = self.model(**features, return_dict=True)
p_reps = self.sentence_embedding(psg_out.last_hidden_state, features['attention_mask'])
if self.normlized:
p_reps = torch.nn.functional.normalize(p_reps, dim=-1)
return p_reps.contiguous()
def compute_similarity(self, q_reps, p_reps):
if len(p_reps.size()) == 2:
return torch.matmul(q_reps, p_reps.transpose(0, 1))
return torch.matmul(q_reps, p_reps.transpose(-2, -1))
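    # Shape note (illustrative): with in-batch negatives q_reps is (B, H) and
    # p_reps is (B*G, H), where G is the number of passages per query (1
    # positive plus negatives), so the 2-D branch above yields a (B, B*G)
    # score matrix. In the grouped path, q_reps is unsqueezed to (B, 1, H)
    # against p_reps reshaped to (B, G, H), and the 3-D branch yields (B, 1, G).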
def forward(self, query: Dict[str, Tensor] = None, passage: Dict[str, Tensor] = None, teacher_score: Tensor = None):
q_reps = self.encode(query)
p_reps = self.encode(passage)
if self.training:
if self.negatives_cross_device and self.use_inbatch_neg:
q_reps = self._dist_gather_tensor(q_reps)
p_reps = self._dist_gather_tensor(p_reps)
group_size = p_reps.size(0) // q_reps.size(0)
            if self.use_inbatch_neg:
                scores = self.compute_similarity(q_reps, p_reps) / self.temperature  # B x (B*G)
                scores = scores.view(q_reps.size(0), -1)
                # the positive for query i sits at column i * group_size
                target = torch.arange(scores.size(0), device=scores.device, dtype=torch.long)
                target = target * group_size
                loss = self.compute_loss(scores, target)
            else:
                # each query is scored only against its own group of passages
                scores = self.compute_similarity(
                    q_reps[:, None, :],
                    p_reps.view(q_reps.size(0), group_size, -1)
                ).squeeze(1) / self.temperature  # B x G
                scores = scores.view(q_reps.size(0), -1)
                # the positive is always the first passage in each group
                target = torch.zeros(scores.size(0), device=scores.device, dtype=torch.long)
                loss = self.compute_loss(scores, target)
else:
scores = self.compute_similarity(q_reps, p_reps)
loss = None
return EncoderOutput(
loss=loss,
scores=scores,
q_reps=q_reps,
p_reps=p_reps,
)
def compute_loss(self, scores, target):
return self.cross_entropy(scores, target)
    def _dist_gather_tensor(self, t: Optional[torch.Tensor]):
        if t is None:
            return None
        t = t.contiguous()
        all_tensors = [torch.empty_like(t) for _ in range(self.world_size)]
        dist.all_gather(all_tensors, t)
        # all_gather returns tensors detached from the autograd graph; put the
        # local shard back so gradients still flow through this rank's reps
        all_tensors[self.process_rank] = t
        all_tensors = torch.cat(all_tensors, dim=0)
        return all_tensors
    def save(self, output_dir: str):
        state_dict = self.model.state_dict()
        state_dict = type(state_dict)(
            {k: v.clone().cpu() for k, v in state_dict.items()}
        )
        self.model.save_pretrained(output_dir, state_dict=state_dict)
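# A minimal forward-pass sketch for BiEncoderModel (the checkpoint name and toy
# texts are illustrative; tokenization via transformers' AutoTokenizer is an
# assumption, mirroring how a trainer would batch inputs):
def _demo_forward():
    from transformers import AutoTokenizer
    name = "BAAI/bge-small-en"  # illustrative checkpoint
    tokenizer = AutoTokenizer.from_pretrained(name)
    model = BiEncoderModel(model_name=name, normlized=True, temperature=0.02)
    model.eval()  # skip the loss branch; forward() then only returns scores
    queries = tokenizer(["what is faiss?"], return_tensors="pt", padding=True)
    passages = tokenizer(["Faiss is a library for similarity search.",
                          "Paris is the capital of France."],
                         return_tensors="pt", padding=True)
    with torch.no_grad():
        out = model(query=queries, passage=passages)
    print(out.scores)  # shape (1, 2); the first passage should score higher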