Commit f75058c7 authored by Rayyyyy

First add.
import os
import logging
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
import torch.distributed as dist
from torch import nn, Tensor
from transformers import AutoModel, AutoTokenizer, AutoConfig
from transformers.file_utils import ModelOutput
from FlagEmbedding.visual.eva_clip import create_eva_vision_and_transforms
from PIL import Image
logger = logging.getLogger(__name__)
@dataclass
class EncoderOutput(ModelOutput):
q_reps: Optional[Tensor] = None
c_reps: Optional[Tensor] = None
loss: Optional[Tensor] = None
scores: Optional[Tensor] = None
class Visualized_BGE(nn.Module):
def __init__(self,
model_name_bge: str = None,
model_weight = None, # "/path/to/your/weight/file/"
normlized: bool = True,
sentence_pooling_method: str = 'cls',
negatives_cross_device: bool = False,
temperature: float = 0.02, # 1.0
from_pretrained=None, # local config file and model
):
super().__init__()
assert model_name_bge in ["BAAI/bge-base-en-v1.5", "BAAI/bge-m3"]
assert model_weight is not None
self.model_name_bge = model_name_bge
if model_name_bge == 'BAAI/bge-base-en-v1.5':
model_name_eva = "EVA02-CLIP-B-16"
self.hidden_dim = 768
self.depth = 12
elif model_name_bge == 'BAAI/bge-m3':
model_name_eva = "EVA02-CLIP-L-14"
self.hidden_dim = 1024
self.depth = 24
if not from_pretrained:
bge_config = AutoConfig.from_pretrained(model_name_bge)
bge = AutoModel.from_config(bge_config)
else:
print("Loading from local path.")
bge_config = AutoConfig.from_pretrained(from_pretrained, local_files_only=True)
bge = AutoModel.from_config(bge_config)
self.bge_encoder = bge.encoder
self.bge_embeddings = bge.embeddings
self.bge_pooler = bge.pooler
self.model_visual, self.preprocess_train, self.preprocess_val= create_eva_vision_and_transforms(
model_name_eva,
force_custom_clip=True)
self.visual_proj = nn.Linear(self.hidden_dim, self.hidden_dim)
self.cross_entropy = nn.CrossEntropyLoss(reduction='mean')
self.normlized = normlized
self.sentence_pooling_method = sentence_pooling_method
self.temperature = temperature
if not normlized:
self.temperature = 1.0
logger.info("reset temperature = 1.0 due to using inner product to compute similarity")
self.negatives_cross_device = negatives_cross_device
if self.negatives_cross_device:
if not dist.is_initialized():
raise ValueError('Distributed training has not been initialized for representation all gather.')
self.process_rank = dist.get_rank()
self.world_size = dist.get_world_size()
self.load_model(model_weight)
if not from_pretrained:
self.tokenizer = AutoTokenizer.from_pretrained(model_name_bge, use_fast=False)
else:
self.tokenizer = AutoTokenizer.from_pretrained(from_pretrained, use_fast=False)
if torch.cuda.is_available():
self.device = torch.device('cuda')
self.to(self.device)
else:
self.device = torch.device('cpu')
def load_model(self, model_weight):
self.load_state_dict(torch.load(model_weight, map_location='cpu'))
def gradient_checkpointing_enable(self, **kwargs):
# self.bge_encoder.gradient_checkpointing_enable()
self.model_visual.set_grad_checkpointing(True)
def encode(self, image=None, text=None):
# used for simple inference
if image is not None:
image = self.preprocess_val(Image.open(image)).unsqueeze(0)
if text is not None:
text = self.tokenizer(text, return_tensors="pt", padding=True)
return self.encode_mm(image.to(self.device), text.to(self.device))
else:
return self.encode_image(image.to(self.device))
else:
if text is not None:
text = self.tokenizer(text, return_tensors="pt", padding=True)
return self.encode_text(text.to(self.device))
else:
return None
def get_extended_attention_mask(
self, attention_mask: Tensor, input_shape: Tuple[int], device: torch.device = None, dtype: torch.float = torch.float16
) -> Tensor:
"""
Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
Arguments:
attention_mask (`torch.Tensor`):
Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
input_shape (`Tuple[int]`):
The shape of the input to the model.
Returns:
`torch.Tensor`: The extended attention mask, cast to the given `dtype`.
"""
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
if attention_mask.dim() == 3:
extended_attention_mask = attention_mask[:, None, :, :]
elif attention_mask.dim() == 2:
# Provided a padding mask of dimensions [batch_size, seq_length]
# - if the model is a decoder, apply a causal mask in addition to the padding mask
# - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
extended_attention_mask = attention_mask[:, None, None, :]
else:
raise ValueError(
f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})"
)
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and the dtype's smallest value for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
extended_attention_mask = extended_attention_mask.to(dtype=dtype) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(dtype).min
return extended_attention_mask
def sentence_embedding(self, hidden_state, mask):
if self.sentence_pooling_method == 'mean':
s = torch.sum(hidden_state * mask.unsqueeze(-1).float(), dim=1)
d = mask.sum(axis=1, keepdim=True).float()
return s / d
elif self.sentence_pooling_method == 'cls':
return hidden_state[:, 0]
def encode_text(self, texts):
'''
encode text only
'''
input_ids = texts['input_ids']
attention_mask = texts['attention_mask']
input_shape = input_ids.size()
device = input_ids.device
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
head_mask = [None] * self.depth
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
embedding_output = self.bge_embeddings(
input_ids=input_ids,
position_ids=None,
token_type_ids=token_type_ids,
inputs_embeds=None,
past_key_values_length=0,
)
encoder_outputs = self.bge_encoder(
embedding_output,
attention_mask=extended_attention_mask,
head_mask=head_mask,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_values=None,
use_cache=False,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
)
sequence_output = encoder_outputs[0]
# pooled_output = self.bge_pooler(sequence_output) if self.bge_pooler is not None else None
t_reps = self.sentence_embedding(sequence_output, texts['attention_mask']) # tensor: reps with pooling
if self.normlized:
t_reps = torch.nn.functional.normalize(t_reps, dim=-1)
return t_reps.contiguous()
def encode_mm(self, images:torch.Tensor, texts):
img_token_emb = self.img_token_embedding(images) #[B, Patch_num, C]
img_token_emb = img_token_emb[:,1:] # img_cls is not used here
img_token_emb = self.visual_proj(img_token_emb)
device = img_token_emb.device
img_token_len = img_token_emb.size()[1]
# image position embedding, default position: bge_cls + img tokens + texts
img_token_position_ids = torch.arange(1, 1 + img_token_len).to(device=device)
img_position_embeddings = self.bge_embeddings.position_embeddings(img_token_position_ids)
img_token_emb = img_token_emb + img_position_embeddings
img_token_emb = self.bge_embeddings.LayerNorm(img_token_emb)
### deal with prompt/text
prompt_input_ids = texts['input_ids']
prompt_attention_mask = texts['attention_mask']
prom_input_shape = prompt_input_ids.size()
# bert
batch_size = prom_input_shape[0]
prompt_len = prom_input_shape[1]
prompt_start = 1 + img_token_len
cls_id = torch.tensor([0]).to(device=device)
prompt_position_ids = torch.arange(prompt_start, prompt_start + prompt_len - 1).to(device=device)
prompt_position_ids = torch.cat([cls_id, prompt_position_ids]).to(device=device)
prompt_token_type_ids = torch.zeros(prom_input_shape, dtype=torch.long, device=device)
prompt_embedding_output = self.bge_embeddings(
input_ids=prompt_input_ids,
position_ids=prompt_position_ids,
token_type_ids=prompt_token_type_ids,
inputs_embeds=None,
past_key_values_length=0,
) # [B, T, C]
cls_token = prompt_embedding_output[:, 0:1, :] # bge_cls token
prompt_embedding_output = prompt_embedding_output[:, 1:]
prompt_img_embedding = torch.cat([cls_token, img_token_emb, prompt_embedding_output], dim=1)
img_attention_mask = torch.ones(batch_size, img_token_len, device=device)
prom_img_attention_mask = torch.cat([img_attention_mask, prompt_attention_mask], dim=1)
prom_img_input_shape = prompt_img_embedding.size()
head_mask = [None] * self.depth
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(prom_img_attention_mask, prom_img_input_shape).to(prompt_img_embedding.dtype)
encoder_outputs = self.bge_encoder(
prompt_img_embedding,
attention_mask=extended_attention_mask,
head_mask=head_mask,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_values=None,
use_cache=False,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
)
sequence_output = encoder_outputs[0]
prompt_img_reps = self.sentence_embedding(sequence_output, prom_img_attention_mask) # tensor: reps with pooling
if self.normlized:
prompt_img_reps = torch.nn.functional.normalize(prompt_img_reps, dim=-1)
return prompt_img_reps
def compute_similarity(self, q_reps, p_reps):
if len(p_reps.size()) == 2:
return torch.matmul(q_reps, p_reps.transpose(0, 1))
return torch.matmul(q_reps, p_reps.transpose(-2, -1))
def img_token_embedding(self, images):
if images is None:
return None
img_token_emb = self.model_visual.encode_image(images, normalize=False) # return_all_features=True, [B, Patch_num, C]
return img_token_emb.contiguous()
def encode_image(self, images):
if images is None:
return None
batch_size = images.shape[0]
prompts = [""] * batch_size
prompts = self.tokenizer(prompts, return_tensors="pt", padding=True)
prompts = prompts.to(images.device)
img_reps = self.encode_mm(images, prompts)
return img_reps
def forward(self, mm_it_query=None, image_candidate=None, text_candidate=None, text_query=None, mm_it_candidate=None, task_type=None):
### for stage-2 training
if task_type == "edit_image":
mm_query_reps = self.encode_mm(mm_it_query[0], mm_it_query[1])
image_candi_reps = self.encode_image(image_candidate)
query_reps = mm_query_reps
candi_reps = image_candi_reps
elif task_type == "t2it":
text_query_reps = self.encode_text(text_query)
mmit_candi_reps = self.encode_mm(mm_it_candidate[0], mm_it_candidate[1])
query_reps = text_query_reps
candi_reps = mmit_candi_reps
if self.training:
if self.negatives_cross_device:
query_reps = self._dist_gather_tensor(query_reps)
candi_reps = self._dist_gather_tensor(candi_reps)
scores = self.compute_similarity(query_reps, candi_reps)
scores = scores / self.temperature
scores = scores.view(query_reps.size(0), -1)
target = torch.arange(scores.size(0), device=scores.device, dtype=torch.long)
target = target * (candi_reps.size(0) // query_reps.size(0))
loss_edit = self.compute_loss(scores, target)
loss = loss_edit
logging.info("task types: %s; loss: %s" %(task_type, str(loss_edit)))
else:
scores = self.compute_similarity(query_reps, candi_reps)
loss=None
return EncoderOutput(
loss=loss,
scores=scores,
q_reps=query_reps,
c_reps=candi_reps,
)
def compute_loss(self, scores, target):
return self.cross_entropy(scores, target)
def _dist_gather_tensor(self, t: Optional[torch.Tensor]):
if t is None:
return None
t = t.contiguous()
all_tensors = [torch.empty_like(t) for _ in range(self.world_size)]
dist.all_gather(all_tensors, t)
all_tensors[self.process_rank] = t
all_tensors = torch.cat(all_tensors, dim=0)
return all_tensors
def save(self, output_dir: str):
torch.save(self.state_dict(), os.path.join(output_dir, 'Visualized_BGE.pth'))
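# --- Minimal usage sketch (illustrative, not part of the original class) ---
# Assumes a downloaded Visualized-BGE weight file; the paths below are placeholders.
if __name__ == "__main__":
    model = Visualized_BGE(
        model_name_bge="BAAI/bge-base-en-v1.5",
        model_weight="/path/to/Visualized_base_en_v1.5.pth",  # placeholder checkpoint path
    )
    model.eval()
    with torch.no_grad():
        # multi-modal query: image plus instruction text
        query_emb = model.encode(image="./query_image.png", text="find a similar image")
        # text-only candidate
        cand_emb = model.encode(text="a photo of a cat sitting on a sofa")
    # both embeddings are L2-normalized, so the inner product is a cosine similarity
    print(float(query_emb @ cand_emb.T))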
MIT License
Copyright (c) 2022 staoxiao
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
from .cocktail import mix_models, mix_models_with_data, mix_models_by_layers
import os
import shutil
import torch
import random
import numpy as np
from typing import List, Dict, Any
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, AutoModel
from sentence_transformers import SentenceTransformer, models
from transformers import pipeline
from .utils import load_model, get_model_param_list, merge_param, compute_weights, get_model_param_dirs, merge_param_by_layer
def save_ckpt_for_sentence_transformers(ckpt_dir, pooling_mode: str = 'cls', normalized: bool = True):
word_embedding_model = models.Transformer(ckpt_dir)
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(),
pooling_mode=pooling_mode)
if normalized:
normalized_layer = models.Normalize()
model = SentenceTransformer(modules=[word_embedding_model, pooling_model, normalized_layer],
device='cpu')
else:
model = SentenceTransformer(modules=[word_embedding_model, pooling_model], device='cpu')
model.save(ckpt_dir)
def mix_models(model_names_or_paths: List[str],
model_type: str,
weights: List[float],
output_path: str=None):
"""_summary_
mix models based on given weights
Args:
model_names_or_paths (List[str]): a list of names or paths to models
model_type (str): type of model to mix, should be in ["decoder", "encoder", "reranker"]
weights (List[float]): a list of mixing weights. The sum of weights should be equal to 1.
output_path (str, optional): path to save the mixed model. Defaults to None.
Returns:
new model
"""
assert len(model_names_or_paths) == len(weights)
assert model_type in ['decoder', 'encoder', 'reranker']
assert abs(sum(weights) - 1) <= 1e-3
param_list = get_model_param_list(model_names_or_paths, model_type=model_type)
new_param = merge_param(param_list, weights=weights)
print("***weight for each model***: ")
for w, n in zip(weights, model_names_or_paths):
print(n, w)
model = load_model(model_names_or_paths[0], model_type=model_type)
model.load_state_dict(new_param)
if output_path is not None:
print(f"Saving the new model to {output_path}")
model.save_pretrained(output_path)
tokenizer = AutoTokenizer.from_pretrained(model_names_or_paths[0], trust_remote_code=True)
tokenizer.save_pretrained(output_path)
if model_type == "encoder":
print(f"Transform the model to the format of 'sentence_transformers' (pooling_method='cls', normalized=True)")
save_ckpt_for_sentence_transformers(ckpt_dir=output_path)
return model
def mix_models_with_data(model_names_or_paths: List[str],
model_type: str,
example_data: List[Dict],
temperature: float=5.0,
batch_size:int=2,
max_input_length:int=2048,
neg_number: int=7,
output_path: str=None):
"""_summary_
mix model based on given a few examples
Args:
model_names_or_paths (List[str]): a list of names or paths to models
model_type (str): type of model to mix, should be in ["decoder", "encoder"]
example_data (List[Any]): a list of examples
temperature (float, optional): temperature impacts the distribution of weights. Defaults to 5.0.
batch_size (int, optional): batch size to compute loss. Defaults to 2.
max_input_length (int, optional): max number of input tokens for model. Defaults to 2048.
neg_number (int, optional): the number of negatives when compute contrastive loss for embedding model. Defaults to 7.
output_path (str, optional): path to save the mixed model. Defaults to None.
Returns:
new model
"""
assert model_type in ['decoder', 'encoder', 'encoder-decoder']
model = load_model(model_names_or_paths[0], model_type=model_type)
tokenizer = AutoTokenizer.from_pretrained(model_names_or_paths[0], trust_remote_code=True)
param_list = get_model_param_list(model_names_or_paths, model_type=model_type)
weights = compute_weights(model, tokenizer=tokenizer, param_list=param_list, model_type=model_type,
example_data=example_data, temperature=temperature, neg_number=neg_number,
batch_size=batch_size, max_input_length=max_input_length)
print("***weight for each model***: ")
for w, n in zip(weights, model_names_or_paths):
print(n, w)
new_param = merge_param(param_list, weights=weights)
model.load_state_dict(new_param)
if output_path is not None:
print(f"Saving the new model to {output_path}")
model.save_pretrained(output_path)
tokenizer.save_pretrained(output_path)
if model_type == "encoder":
print(f"Transform the model to the format of 'sentence_transformers' (pooling_method='cls', normalized=True)")
save_ckpt_for_sentence_transformers(ckpt_dir=output_path)
return model
def mix_models_by_layers(model_names_or_paths: List[str],
model_type: str,
weights: List[float],
output_path: str=None):
"""_summary_
mix models based on given weights, and load them layer by layer
Args:
model_names_or_paths (List[str]): a list of names or paths to models
model_type (str): type of model to mix, should be in ["decoder", "encoder", "reranker"]
weights (List[float]): a list of mixing weights. The sum of weights should be equal to 1.
output_path (str, optional): path to save the mixed model. Defaults to None.
Returns:
new model
"""
assert len(model_names_or_paths) == len(weights)
assert model_type in ['decoder', 'encoder', 'reranker']
assert abs(sum(weights) - 1) <= 1e-3
param_dirs, temp_dir = get_model_param_dirs(model_names_or_paths, model_type=model_type)
temp_file_path = merge_param_by_layer(param_dirs, weights=weights)
print("***weight for each model***: ")
for w, n in zip(weights, model_names_or_paths):
print(n, w)
with init_empty_weights():
if model_type == 'decoder':
meta_model = AutoModelForCausalLM.from_pretrained(model_names_or_paths[0], trust_remote_code=True)
elif model_type == 'encoder':
meta_model = AutoModel.from_pretrained(model_names_or_paths[0], trust_remote_code=True)
elif model_type == 'reranker':
meta_model = AutoModelForSequenceClassification.from_pretrained(model_names_or_paths[0], trust_remote_code=True)
else:
raise NotImplementedError(f"not support this model_type: {model_type}")
device_map = {name: "cpu" for name, _ in meta_model.named_modules()}
model = load_checkpoint_and_dispatch(meta_model, checkpoint=temp_file_path, device_map=device_map)
model.tie_weights()
os.remove(temp_file_path)
shutil.rmtree(temp_dir)
print(f"Remove temporary file: {temp_file_path}")
print(f"Remove temporary directory: {temp_dir}")
if output_path is not None:
print(f"Saving the new model to {output_path}")
model.save_pretrained(output_path)
tokenizer = AutoTokenizer.from_pretrained(model_names_or_paths[0])
tokenizer.save_pretrained(output_path)
if model_type == "encoder":
print(f"Transform the model to the format of 'sentence_transformers' (pooling_method='cls', normalized=True)")
save_ckpt_for_sentence_transformers(ckpt_dir=output_path)
return model
import os
import gc
import tempfile
import torch
import random
import numpy as np
from tqdm import tqdm
from typing import List, Dict, Any
from transformers import AutoModelForCausalLM, AutoModel, AutoModelForSequenceClassification, AutoModelForSeq2SeqLM, is_torch_npu_available
def load_llm(model_name:str, trust_remote_code:bool):
model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=trust_remote_code, device_map = {"": "cpu"})
return model
def load_embedder(model_name:str, trust_remote_code:bool):
model = AutoModel.from_pretrained(model_name, trust_remote_code=trust_remote_code, device_map = {"": "cpu"})
return model
def load_reranker(model_name:str, trust_remote_code:bool):
model = AutoModelForSequenceClassification.from_pretrained(model_name, trust_remote_code=trust_remote_code, device_map = {"": "cpu"})
return model
def load_seq2seq_model(model_name:str, trust_remote_code:bool):
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=trust_remote_code)
return model
def load_model(model_name:str, model_type:str, trust_remote_code:bool=True):
if model_type == 'decoder':
model = load_llm(model_name, trust_remote_code=trust_remote_code)
elif model_type == 'encoder':
model = load_embedder(model_name, trust_remote_code=trust_remote_code)
elif model_type == 'reranker':
model = load_reranker(model_name, trust_remote_code=trust_remote_code)
elif model_type == 'encoder-decoder':
model = load_seq2seq_model(model_name, trust_remote_code=trust_remote_code)
else:
raise NotImplementedError(f"not support this model_type: {model_type}")
return model
def get_model_param_list(model_names: List[str], model_type:str):
model_param_list = []
for name in model_names:
print(f"loading {name} -----------------")
model = load_model(name, model_type=model_type)
model_param_list.append(model.state_dict())
return model_param_list
def merge_param(model_param_list: List[Dict], weights: List[float]):
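    # Weighted average over the given state dicts; integer-typed entries
    # (e.g. position-id buffers) are copied through unchanged instead of being averaged.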
new_param = {}
for k in model_param_list[0].keys():
for w, param in zip(weights, model_param_list):
if param[k].dtype == torch.int64 or param[k].dtype == torch.int32:
new_param[k] = param[k]
elif k not in new_param:
new_param[k] = w * param[k]
else:
new_param[k] += w * param[k]
return new_param
def get_model_param_dirs(model_names: List[str], model_type:str):
param_dirs = []
temp_dir = tempfile.mkdtemp()
print(f"create a temporary directory: {temp_dir}")
for idx, name in enumerate(model_names):
print(f"loading {name} -----------------")
model = load_model(name, model_type=model_type)
model_params = model.state_dict()
model_temp_dir = os.path.join(temp_dir, f"model_{idx+1}")
os.makedirs(model_temp_dir, exist_ok=True)
param_dirs.append(model_temp_dir)
for k, v in model_params.items():
temp_param_file = os.path.join(model_temp_dir, f"{k}.ckpt")
torch.save(v, temp_param_file)
model = model.to("meta")
del model_params
gc.collect()
return param_dirs, temp_dir
def merge_param_by_layer(model_param_dirs: List[str], weights: List[float]):
new_param = {}
model_params = os.listdir(model_param_dirs[0])
for param_file in tqdm(model_params, desc="Merging models"):
param_name = param_file.replace(".ckpt", "")
for w, model_dir in tqdm(zip(weights, model_param_dirs), total=len(weights), desc=f"Processing {param_name}", leave=False):
file_path = os.path.join(model_dir, param_file)
param = torch.load(file_path)
if param.dtype in [torch.int64, torch.int32]:
new_param[param_name] = param
elif param_name not in new_param:
new_param[param_name] = w * param
else:
new_param[param_name] += w * param
del param
gc.collect()
with tempfile.NamedTemporaryFile(delete=False, suffix=".ckpt") as tmp_file:
print(f"create a temporary file to store mixed weights: {tmp_file.name}")
torch.save(new_param, tmp_file.name)
temp_file_path = tmp_file.name
del new_param
gc.collect()
return temp_file_path
def compute_weights(base_model, tokenizer, param_list: List[Dict], model_type: str, example_data: List[Any], temperature: float=5.0, batch_size:int=2, max_input_length:int=2048, neg_number:int=7):
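    # The merging weight of each candidate model is softmax(-loss / temperature), where the
    # loss is measured on the given example data; a lower loss yields a larger merging weight.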
if torch.cuda.is_available():
device = torch.device("cuda")
elif is_torch_npu_available():
device = torch.device("npu")
else:
device = torch.device("cpu")
base_model = base_model.to(device)
if model_type == 'decoder':
input_data = preprocess_data_for_llm(example_data=example_data, tokenizer=tokenizer, device=device, batch_size=batch_size, max_input_length=max_input_length)
loss_func = llm_loss
elif model_type == 'encoder':
input_data = preprocess_data_for_embedder(example_data=example_data, tokenizer=tokenizer, device=device, batch_size=batch_size, max_input_length=max_input_length, neg_number=neg_number)
loss_func = embedder_loss
elif model_type == 'encoder-decoder':
input_data = preprocess_data_for_seq2seq(example_data=example_data, tokenizer=tokenizer, device=device, batch_size=batch_size, max_input_length=max_input_length)
loss_func = seq2seq_loss
example_loss = []
with torch.no_grad():
for params in param_list:
base_model.load_state_dict(params)
loss = loss_func(base_model=base_model, input_data=input_data)
example_loss.append(loss)
weights = torch.softmax(-torch.FloatTensor(example_loss)/temperature, -1).numpy().tolist()
return weights
def preprocess_data_for_seq2seq(example_data, tokenizer, device, batch_size:int=2, max_input_length:int=512): # Added Reimer
batch_data = []
for i in range(0, len(example_data), batch_size):
batch_examples = example_data[i:i+batch_size]
input_texts = [ex['input'] for ex in batch_examples]
target_texts = [ex['output'] for ex in batch_examples]
input_encodings = tokenizer(input_texts, text_target=target_texts, max_length=max_input_length, padding=True, truncation=True, return_tensors="pt")
input_ids = input_encodings.input_ids.to(device)
attention_mask = input_encodings.attention_mask.to(device)
labels = input_encodings.labels.to(device)
labels[labels == tokenizer.pad_token_id] = -100
batch_data.append({
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels
})
return batch_data
def preprocess_data_for_embedder(example_data, tokenizer, device, batch_size:int=64, max_input_length:int=512, neg_number:int=7):
input_data = []
quries = []
passages = []
# max_input_length = min(512, max_input_length)
for e in example_data:
quries.append(e['query'])
passages.append(e['pos'][0])
passages.extend(random.sample(e['neg'], neg_number))
if len(quries) == batch_size:
q_tokens = tokenizer(quries, padding=True, truncation=True, max_length=max_input_length, return_tensors="pt")
p_tokens = tokenizer(passages, padding=True, truncation=True, max_length=max_input_length, return_tensors="pt")
q_tokens, p_tokens = q_tokens.to(device), p_tokens.to(device)
input_data.append([q_tokens, p_tokens])
quries, passages = [], []
if len(quries) > 0:
q_tokens = tokenizer(quries, padding=True, truncation=True, max_length=max_input_length, return_tensors="pt")
p_tokens = tokenizer(passages, padding=True, truncation=True, max_length=max_input_length, return_tensors="pt")
q_tokens, p_tokens = q_tokens.to(device), p_tokens.to(device)
input_data.append([q_tokens, p_tokens])
return input_data
def seq2seq_loss(base_model, input_data):
total_loss = 0
with torch.no_grad():
for batch in input_data:
outputs = base_model(input_ids=batch["input_ids"],
attention_mask=batch["attention_mask"],
labels=batch["labels"])
total_loss += outputs.loss.cpu()
average_loss = total_loss / len(input_data)
return float(average_loss)
def embedder_loss(base_model, input_data):
def generate_embeddings(model, inputs):
embeddings = model(**inputs, return_dict=True).last_hidden_state[:, 0]
embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
return embeddings
with torch.no_grad():
loss = 0
for q_inputs, p_inputs in input_data:
q_embeddings = generate_embeddings(base_model, q_inputs)
p_embeddings = generate_embeddings(base_model, p_inputs)
scores = torch.matmul(q_embeddings, p_embeddings.transpose(0, 1)) / 0.05
target = torch.arange(scores.size(0), device=scores.device, dtype=torch.long)
target = target * (p_embeddings.size(0) // q_embeddings.size(0))
batch_loss = torch.nn.CrossEntropyLoss(reduction='mean')(scores, target)
loss += batch_loss.cpu()
loss = float(loss / len(input_data))
return float(loss)
def preprocess_data_for_llm(example_data, tokenizer, device, batch_size:int=2, max_input_length:int=2048):
batch_input_ids = []
batch_labels = []
batch_max_length = max_input_length
for data in example_data:
input, output = data['input'], data['output']
input_ids = tokenizer.encode(input+' '+output)
input_ids.append(tokenizer.eos_token_id)
prompt_ids = tokenizer.encode(input)
labels = [-100]*len(prompt_ids) + input_ids[len(prompt_ids):]
input_ids = input_ids[:batch_max_length]
input_ids += [tokenizer.pad_token_id] * (batch_max_length - len(input_ids))
batch_input_ids.append(input_ids)
labels = labels[:batch_max_length]
labels += [-100] * (batch_max_length - len(labels))
batch_labels.append(labels)
batch_input_ids = torch.LongTensor(batch_input_ids).to(device)
batch_labels = torch.LongTensor(batch_labels).to(device)
attention_mask = batch_input_ids.ne(tokenizer.pad_token_id).to(device)
batch_data = []
for i in range(0, len(batch_input_ids), batch_size):
batch_data.append(dict(
input_ids=batch_input_ids[i:i+batch_size],
labels=batch_labels[i:i+batch_size],
attention_mask=attention_mask[i:i+batch_size],
))
return batch_data
def llm_loss(base_model, input_data):
loss = 0
with torch.no_grad():
for data in input_data:
output = base_model(**data)
loss += output.loss.cpu()
loss = float(loss / len(input_data))
return loss
<div align="center">
<h1> <a href="https://arxiv.org/abs/2311.13534">LM-Cocktail: Resilient Tuning of Language Models via Model Merging</a> </h1>
<img src="images/1.png" width="30%" class="center">
</div>
**Make fine-tuning of language models akin to crafting a nuanced cocktail.**
Model merging can be used to improve the performance of a single model.
We find this method is also useful for large language models and dense embedding models,
and design the LM-Cocktail strategy, which automatically merges fine-tuned models and the base model using a simple function to compute the merging weights.
LM-Cocktail can be used to improve performance on the target domain without decreasing
the general capabilities beyond that domain.
It can also be used to generate a model for new tasks without fine-tuning.
For more details please refer to our report: [LM-Cocktail](https://arxiv.org/abs/2311.13534).
## Application
The following are some application scenarios (note that the models to be merged must have the same architecture and the same initialization parameters):
### 1. Mitigate the problem of Catastrophic Forgetting
Fine-tuning the base language model can lead to severe degradation of the model's general capabilities beyond the targeted domain.
By mixing the fine-tuned model and the base model (using the function `mix_models`), LM-Cocktail can significantly enhance performance on the downstream task
while maintaining performance on other, unrelated tasks.
If there are available models fine-tuned on other tasks, you can further use them to enhance your fine-tuned model.
First, collect a few examples from your task (e.g., five), then use the function `mix_models_with_data` to compute weights and merge the available models.
This assigns lower weights to low-quality models, avoiding a degradation of performance on your task.
Finally, use `mix_models` to merge the produced model with your fine-tuned model, as sketched below.
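A minimal sketch of this two-step workflow; the fine-tuned model path and the example file are placeholders, and the auxiliary checkpoints are the task-specific models referenced elsewhere in this README:
```python
import json
from LM_Cocktail import mix_models, mix_models_with_data

# Step 1: merge the available task-specific models, weighting them by a few
# examples from your own task ("examples.json" is a placeholder file in the
# embedder example format shown below).
example_data = json.load(open("examples.json"))
helper_model = mix_models_with_data(
    model_names_or_paths=["BAAI/bge-base-en-v1.5", "Shitao/bge-hotpotqa", "Shitao/bge-quora"],
    model_type='encoder',
    example_data=example_data,
    temperature=5.0,
    output_path='./helper_model')

# Step 2: merge the produced model with your own fine-tuned model
# ("./my-finetuned-bge" is a placeholder path).
model = mix_models(
    model_names_or_paths=["./my-finetuned-bge", "./helper_model"],
    model_type='encoder',
    weights=[0.5, 0.5],
    output_path='./final_model')
```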
### 2. Improve the performance of new task without fine-tuning
LM-Cocktail can improve accuracy on a new task without requiring any fine-tuning.
Given a few example data points (e.g., five examples)
and some available models (from the open-source community, or pre-existing models for other tasks),
the function `mix_models_with_data` can automatically assign different merging weights to the different models
based on their loss on the example data, and then merge these available models into a new task-specific model. A sketch of the weight computation follows.
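Under the hood (see `compute_weights` in `utils.py`), the merging weights are a softmax over the negative example losses; a minimal sketch with hypothetical loss values:
```python
import torch

# loss of each candidate model on the few example data points (hypothetical values)
example_loss = [0.8, 1.5, 0.6]
temperature = 5.0
weights = torch.softmax(-torch.FloatTensor(example_loss) / temperature, dim=-1).tolist()
print(weights)  # lower loss -> larger merging weight
```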
### 3. Approximate multitask learning
If you have models that are fine-tuned on different tasks, you can merge them into one model to approximate multitask learning.
The merged model can be used to perform multiple tasks, as in the sketch below.
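A minimal sketch merging several task-specific embedders with equal weights (the model names reuse the checkpoints referenced elsewhere in this README):
```python
from LM_Cocktail import mix_models

# merge three task-specific embedders into a single multi-task model
model = mix_models(
    model_names_or_paths=["Shitao/bge-hotpotqa", "Shitao/bge-quora", "Shitao/bge-msmarco"],
    model_type='encoder',
    weights=[0.34, 0.33, 0.33],  # the weights should sum to 1
    output_path='./multitask_embedder')
```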
## Usage
Install the latest version from source (Recommended):
```bash
git clone https://github.com/FlagOpen/FlagEmbedding.git
cd FlagEmbedding/LM_Cocktail
pip install -e .
```
Install by pip:
```bash
pip install -U LM_Cocktail
```
There are three key functions in LM-Cocktail:
### 1. Mix models
`mix_models` can merge models based on the given merging weights.
An example is merging the fine-tuned model and
the base model to mitigate Catastrophic Forgetting after fine-tuning:
```python
from LM_Cocktail import mix_models, mix_models_with_data
# mix LLMs and save it to output_path: ./mixed_model_1
model = mix_models(
model_names_or_paths=["meta-llama/Llama-2-7b-chat-hf", "Shitao/llama2-ag-news"],
model_type='decoder',
weights=[0.7, 0.3],
output_path='./mixed_llm')
# you can select a weight for your models to get a trade-off between generality and expertise.
# Mix Embedding Models
model = mix_models(
model_names_or_paths=["BAAI/bge-base-en-v1.5", "Shitao/bge-hotpotqa"],
model_type='encoder',
weights=[0.5, 0.5],
output_path='./mixed_embedder')
# Mix reranker Models
model = mix_models(
model_names_or_paths=["BAAI/bge-reranker-base", "BAAI/bge-reranker-base"],
model_type='reranker',
weights=[0.5, 0.5],
output_path="./mixed_reranker")
```
Note that the sum of the weights should be equal to 1.
You can also merge multiple models:
```python
from LM_Cocktail import mix_models, mix_models_with_data
model = mix_models(
model_names_or_paths=["BAAI/bge-base-en-v1.5", "Shitao/bge-hotpotqa", "Shitao/bge-quora", "Shitao/bge-msmarco"],
model_type='encoder',
weights=[0.3, 0.2, 0.2, 0.3],
output_path='./mixed_embedder_2')
# The sum of weights should be equal to 1.
```
### 2. Mix models with weights computed based on a few examples
`mix_models_with_data` can compute merging weights based on given data and merge models.
It can be used to produce a model for a new task without training,
or to boost performance on a downstream task by leveraging the knowledge in other models.
- For LLMs
The format of `example_data` for LLMs is a list, where each item is a dict like:
```
{"input": str, "output": str}
```
LM-Cocktail will compute the loss on the output.
You can use the example data to merge models as follows:
```python
from LM_Cocktail import mix_models, mix_models_with_data
example_data = [
{"input": "Question: when was the last time anyone was on the moon? Answer:\n", "output": "14 December 1972 UTC"},
{"input": "Review: \"it 's a charming and often affecting journey . \" Is this movie review sentence negative or positive?\n", "output": "Positive"}
]
model = mix_models_with_data(
model_names_or_paths=["meta-llama/Llama-2-7b-chat-hf", "Shitao/llama2-ag-news", "Shitao/llama2-nq"],
model_type='decoder',
example_data=example_data,
temperature=5.0)
# you can set the temperature argument to adjust the distribution of mixing weights
```
- For Embedder
The format of `example_data` for embedding models is a list, where each item is a dict like:
```
{"query": str, "pos": List[str], 'neg': List[str]}
```
where `pos` is a list of positive texts and `neg` is a list of negative texts. LM-Cocktail will compute the contrastive loss.
You can use the example data to merge models as follows:
```python
from LM_Cocktail import mix_models, mix_models_with_data
example_data = [
{"query": "How does one become an actor in the Telugu Film Industry?", "pos": [" How do I become an actor in Telugu film industry?"], "neg": [" What is the story of Moses and Ramesses?", " Does caste system affect economic growth of India?"]},
{"query": "Why do some computer programmers develop amazing software or new concepts, while some are stuck with basic programming work?", "pos": [" Why do some computer programmers develops amazing softwares or new concepts, while some are stuck with basics programming works?"], "neg": [" When visiting a friend, do you ever think about what would happen if you did something wildly inappropriate like punch them or destroy their furniture?", " What is the difference between a compliment and flirting?"]}
]
model = mix_models_with_data(
model_names_or_paths=["BAAI/bge-base-en-v1.5", "Shitao/bge-hotpotqa", "Shitao/bge-quora"],
model_type='encoder',
example_data=example_data,
temperature=5.0,
max_input_length=512,
neg_number=2)
```
### 3. Mix models layer by layer for reducing memory cost
The function `mix_models_by_layers` creates temporary directories to store weights of individual models and then merges them layer by layer.
This approach helps reduce memory consumption.
Once the merging process is completed, the temporary directories and files will be automatically removed.
```python
from LM_Cocktail import mix_models_by_layers
# Mix Large Language Models (LLMs) and save the combined model to the path: ./mixed_llm
model = mix_models_by_layers(
model_names_or_paths=["meta-llama/Llama-2-7b-chat-hf", "Shitao/llama2-ag-news"],
model_type='decoder',
weights=[0.7, 0.3],
output_path='./mixed_llm')
```
## Performance
For detailed results, please refer to our report: [LM-Cocktail](https://arxiv.org/abs/2311.13534).
- LM-Cocktail for Catastrophic Forgetting
| Model | Target Task | Others(29 tasks) |
|:---------------------------|:--------:|:----------------:|
| Llama | 40.8 | 46.8 |
| Fine-tuned | 94.4 | 38.6 |
| LM-Cocktail(2 models) [1] | 94.5 | 47.7 |
| LM-Cocktail(10 models) [2] | 94.4 | 48.3 |
[1]: merge 2 models: fine-tuned model and the base model
[2]: merge 10 models based on five examples: fine-tuned model, the base model, and 8 models fine-tuned on other tasks
| Model | Target Task | Other Tasks(14 tasks) |
|:-------------------------------|:--------:|:---------------------:|
| BGE | 71.8 | 49.8 |
| Fine-tuned | 76.0 | 48.5 |
| LM-Cocktail(2 models) | 74.8 | 50.0 |
| LM-Cocktail(10 models) | 74.7 | 50.6 |
- LM-Cocktail for new tasks without fine-tuning
Merging 10 models fine-tuned on other tasks, based on five examples from the new tasks:
| Model | MMLU(57 tasks) |
|:-------------------------------|:--------------:|
| Llama | 45.9 |
| Llama-5shot | 46.7 |
| LM-Cocktail(10 models) | 48.0 |
| Model | Retrieval(12 tasks) |
|:-------------------------------|:-------------------:|
| BGE | 47.3 |
| LM-Cocktail(10 models) | 48.8 |
## Evaluation
### 1. Reproduce the results of LLM
- Models: we fine-tune [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) on 9 tasks, and you can find the fine-tuned models at this [link](https://huggingface.co/Shitao). Note that most of the fine-tuned models have poor performance on other, unrelated tasks.
- Example data for the datasets from FLAN: [./llm_examples.json]()
- MMLU dataset: https://huggingface.co/datasets/cais/mmlu (use the examples in the dev set for in-context learning)
You can use these models and our code to produce a new model and evaluate its performance using the [llm-embedder script](https://github.com/FlagOpen/FlagEmbedding/blob/master/FlagEmbedding/llm_embedder/docs/evaluation.md) as follows:
```bash
# for 30 tasks from FLAN
torchrun --nproc_per_node 8 -m evaluation.eval_icl \
--retrieval_method no \
--few_shot 0 \
--data_root /data/llm-embedder \
--model_name_or_path ./mixed_model_1
# for MMLU datasets
torchrun --nproc_per_node 8 -m evaluation.eval_mmlu \
--retrieval_method no \
--few_shot 0 \
--data_root /data/llm-embedder \
--model_name_or_path ./mixed_model_2
```
### 2. Reproduce the results of Embedding Model
- Models: we fine-tune the [bge-base-en-v1.5](https://huggingface.co/BAAI/bge-base-en-v1.5) on 9 tasks, and you can find the fine-tuned models at this [link](https://huggingface.co/Shitao).
- Example data: [./embedder_examples.json]()
Use [MTEB script](https://github.com/FlagOpen/FlagEmbedding/tree/master/C_MTEB) to evaluate the mixed embedding model:
```bash
python eval_MTEB.py --model_name_or_path mixed_model --task_type Retrieval
```
## Acknowledgement
This project is inspired by previous research on model merging, including [LoraHub](https://github.com/sail-sg/lorahub), [model soups](https://github.com/mlfoundations/model-soups), and [PAINT](https://github.com/mlfoundations/patching).
The Llama models are fine-tuned using the [FastChat](https://github.com/lm-sys/FastChat) scripts.
The fine-tuning datasets are from [sentence-transformers/embedding-training-data](https://huggingface.co/datasets/sentence-transformers/embedding-training-data) and [intfloat/llm-retriever-tasks](https://huggingface.co/datasets/intfloat/llm-retriever-tasks).
## Citation
If you find this repository useful, please consider giving it a star :star: and a citation:
```
@misc{cocktail,
title={LM-Cocktail: Resilient Tuning of Language Models via Model Merging},
author={Shitao Xiao and Zheng Liu and Peitian Zhang and Xingrun Xing},
year={2023},
eprint={2311.13534},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
```
{"mnli_m": [{"input": "Premise: \"The new rights are nice enough\" Hypothesis: \"Everyone really likes the newest benefits \" Does the premise entail the hypothesis? Yes, No, or Maybe?\n", "output": "Maybe"}, {"input": "Premise: \"This site includes a list of all award winners and a searchable database of Government Executive articles.\" Hypothesis: \"The Government Executive articles housed on the website are not able to be searched.\" Does the premise entail the hypothesis? Yes, No, or Maybe?\n", "output": "No"}, {"input": "Premise: \"uh i don't know i i have mixed emotions about him uh sometimes i like him but at the same times i love to see somebody beat him\" Hypothesis: \"I like him for the most part, but would still enjoy seeing someone beat him.\" Does the premise entail the hypothesis? Yes, No, or Maybe?\n", "output": "Yes"}, {"input": "Premise: \"yeah i i think my favorite restaurant is always been the one closest you know the closest as long as it's it meets the minimum criteria you know of good food\" Hypothesis: \"My favorite restaurants are always at least a hundred miles away from my house. \" Does the premise entail the hypothesis? Yes, No, or Maybe?\n", "output": "No"}, {"input": "Premise: \"i don't know um do you do a lot of camping\" Hypothesis: \"I know exactly.\" Does the premise entail the hypothesis? Yes, No, or Maybe?\n", "output": "No"}], "mrpc": [{"input": "Here are two sentences: He said the foodservice pie business doesn 't fit the company 's long-term growth strategy . \" The foodservice pie business does not fit our long-term growth strategy . Do they have the same meaning?\n", "output": "Yes"}, {"input": "Here are two sentences: Magnarelli said Racicot hated the Iraqi regime and looked forward to using his long years of training in the war . His wife said he was \" 100 percent behind George Bush \" and looked forward to using his years of training in the war . Do they have the same meaning?\n", "output": "No"}, {"input": "Here are two sentences: The dollar was at 116.92 yen against the yen , flat on the session , and at 1.2891 against the Swiss franc , also flat . The dollar was at 116.78 yen JPY = , virtually flat on the session , and at 1.2871 against the Swiss franc CHF = , down 0.1 percent . Do they have the same meaning?\n", "output": "No"}, {"input": "Here are two sentences: The AFL-CIO is waiting until October to decide if it will endorse a candidate . The AFL-CIO announced Wednesday that it will decide in October whether to endorse a candidate before the primaries . Do they have the same meaning?\n", "output": "Yes"}, {"input": "Here are two sentences: No dates have been set for the civil or the criminal trial . No dates have been set for the criminal or civil cases , but Shanley has pleaded not guilty . Do they have the same meaning?\n", "output": "No"}], "natural_questions": [{"input": "Question: when was the last time anyone was on the moon? Answer:\n", "output": "14 December 1972 UTC"}, {"input": "Question: who wrote he ain't heavy he's my brother lyrics? Answer:\n", "output": "Bobby Scott"}, {"input": "Question: how many seasons of the bastard executioner are there? Answer:\n", "output": "one"}, {"input": "Question: when did the eagles win last super bowl? Answer:\n", "output": "2017"}, {"input": "Question: who won last year's ncaa women's basketball? 
Answer:\n", "output": "South Carolina"}], "squad_v1": [{"input": "Please answer a question about the following article about Super Bowl 50: Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24\u201310 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi's Stadium in the San Francisco Bay Area at Santa Clara, California. As this was the 50th Super Bowl, the league emphasized the \"golden anniversary\" with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as \"Super Bowl L\"), so that the logo could prominently feature the Arabic numerals 50. Which NFL team represented the AFC at Super Bowl 50?\n", "output": "Denver Broncos"}, {"input": "Please answer a question about the following article about Super Bowl 50: Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24\u201310 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi's Stadium in the San Francisco Bay Area at Santa Clara, California. As this was the 50th Super Bowl, the league emphasized the \"golden anniversary\" with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as \"Super Bowl L\"), so that the logo could prominently feature the Arabic numerals 50. Which NFL team represented the NFC at Super Bowl 50?\n", "output": "Carolina Panthers"}, {"input": "Please answer a question about the following article about Super Bowl 50: Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24\u201310 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi's Stadium in the San Francisco Bay Area at Santa Clara, California. As this was the 50th Super Bowl, the league emphasized the \"golden anniversary\" with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as \"Super Bowl L\"), so that the logo could prominently feature the Arabic numerals 50. Where did Super Bowl 50 take place?\n", "output": "Santa Clara, California"}, {"input": "Please answer a question about the following article about Super Bowl 50: Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24\u201310 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi's Stadium in the San Francisco Bay Area at Santa Clara, California. 
As this was the 50th Super Bowl, the league emphasized the \"golden anniversary\" with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as \"Super Bowl L\"), so that the logo could prominently feature the Arabic numerals 50. Which NFL team won Super Bowl 50?\n", "output": "Denver Broncos"}, {"input": "Please answer a question about the following article about Super Bowl 50: Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24\u201310 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi's Stadium in the San Francisco Bay Area at Santa Clara, California. As this was the 50th Super Bowl, the league emphasized the \"golden anniversary\" with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as \"Super Bowl L\"), so that the logo could prominently feature the Arabic numerals 50. What color was used to emphasize the 50th anniversary of the Super Bowl?\n", "output": "gold"}], "sst2": [{"input": "Review: \"it 's a charming and often affecting journey . \" Is this movie review sentence negative or positive?\n", "output": "Positive"}, {"input": "Review: \"unflinchingly bleak and desperate \" Is this movie review sentence negative or positive?\n", "output": "Negative"}, {"input": "Review: \"allows us to hope that nolan is poised to embark a major career as a commercial yet inventive filmmaker . \" Is this movie review sentence negative or positive?\n", "output": "Positive"}, {"input": "Review: \"the acting , costumes , music , cinematography and sound are all astounding given the production 's austere locales . \" Is this movie review sentence negative or positive?\n", "output": "Positive"}, {"input": "Review: \"it 's slow -- very , very slow . \" Is this movie review sentence negative or positive?\n", "output": "Negative"}], "winogrande": [{"input": "How does the sentence end? Sarah was a much better surgeon than Maria so\n", "output": "Maria always got the easier cases."}, {"input": "How does the sentence end? Sarah was a much better surgeon than Maria so\n", "output": "Sarah always got the harder cases."}, {"input": "How does the sentence end? They were worried the wine would ruin the bed and the blanket, but the\n", "output": "bed was't ruined."}, {"input": "How does the sentence end? Terry tried to bake the eggplant in the toaster oven but the\n", "output": "eggplant was too big."}, {"input": "How does the sentence end? At night, Jeffrey always stays up later than Hunter to watch TV because\n", "output": "Jeffrey wakes up late."}], "ag_news": [{"input": "\"Fears for T N pension after talks Unions representing workers at Turner Newall say they are 'disappointed' after talks with stricken parent firm Federal Mogul.\" What is this text about? 
World, Sports, Business, or Technology?\n", "output": "Business"}, {"input": "\"The Race is On: Second Private Team Sets Launch Date for Human Spaceflight (SPACE.com) SPACE.com - TORONTO, Canada -- A second\\team of rocketeers competing for the #36;10 million Ansari X Prize, a contest for\\privately funded suborbital space flight, has officially announced the first\\launch date for its manned rocket.\" What is this text about? World, Sports, Business, or Technology?\n", "output": "Technology"}, {"input": "\"Ky. Company Wins Grant to Study Peptides (AP) AP - A company founded by a chemistry researcher at the University of Louisville won a grant to develop a method of producing better peptides, which are short chains of amino acids, the building blocks of proteins.\" What is this text about? World, Sports, Business, or Technology?\n", "output": "Technology"}, {"input": "\"Prediction Unit Helps Forecast Wildfires (AP) AP - It's barely dawn when Mike Fitzpatrick starts his shift with a blur of colorful maps, figures and endless charts, but already he knows what the day will bring. Lightning will strike in places he expects. Winds will pick up, moist places will dry and flames will roar.\" What is this text about? World, Sports, Business, or Technology?\n", "output": "Technology"}, {"input": "\"Calif. Aims to Limit Farm-Related Smog (AP) AP - Southern California's smog-fighting agency went after emissions of the bovine variety Friday, adopting the nation's first rules to reduce air pollution from dairy cow manure.\" What is this text about? World, Sports, Business, or Technology?\n", "output": "Technology"}], "common_gen": [{"input": "Concepts: field, look, stand. Write a sentence that includes all these words.\n", "output": "The player stood in the field looking at the batter."}, {"input": "Concepts: field, look, stand. Write a sentence that includes all these words.\n", "output": "The coach stands along the field, looking at the goalkeeper."}, {"input": "Concepts: field, look, stand. Write a sentence that includes all these words.\n", "output": "I stood and looked across the field, peacefully."}, {"input": "Concepts: field, look, stand. Write a sentence that includes all these words.\n", "output": "Someone stands, looking around the empty field."}, {"input": "Concepts: kid, room, dance. Write a sentence that includes all these words.\n", "output": "The silly kid loves to dance in her room."}], "hellaswag": [{"input": "What happens next in this paragraph? A man is sitting on a roof. he\n", "output": "starts pulling up roofing on a roof."}, {"input": "What happens next in this paragraph? A lady walks to a barbell. She bends down and grabs the pole. the lady\n", "output": "stands and lifts the weight over her head."}, {"input": "What happens next in this paragraph? Two women in a child are shown in a canoe while a man pulls the canoe while standing in the water, with other individuals visible in the background. the child and a different man\n", "output": "sit in a canoe while the man paddles."}, {"input": "What happens next in this paragraph? A boy is running down a track. the boy\n", "output": "lifts his body above the height of a pole."}, {"input": "What happens next in this paragraph? The boy lifts his body above the height of a pole. The boy lands on his back on to a red mat. the boy\n", "output": "gets up from the mat."}]}
from setuptools import setup, find_packages
with open("README.md", mode="r", encoding="utf-8") as readme_file:
readme = readme_file.read()
setup(
name='LM_Cocktail',
version='0.0.5',
description='LM_Cocktail',
long_description=readme,
long_description_content_type="text/markdown",
author_email='2906698981@qq.com',
url='https://github.com/FlagOpen/FlagEmbedding/LM_Cocktail',
packages=find_packages(),
install_requires=[
'torch>=1.6.0',
'transformers>=4.18.0',
'datasets',
'accelerate>=0.20.1'
],
)
outputs
results
pretrain
<div align="center">
<h1>Soaring from 4K to 400K: Extending LLM's Context with Activation Beacon [<a href="https://arxiv.org/abs/2401.03462">paper</a>]</h1>
</div>
This is the codebase for Activation Beacon, an effective, efficient, compatible, and low-cost (training) method to extend the context length of LLMs by compressing the KV cache.
## File structure:
- The [old](./old/) folder contains our initial implementation of Activation Beacon for Llama-2. You can use the code in it to reproduce the training/evaluation of the Llama-2 based model shown in our paper.
- The [new](./new/) folder contains a **newer** implementation of Activation Beacon. It supports more LLMs, including Mistral, Llama-3, and Qwen-2. It also supports more features, including **Deepspeed Zero3 training**, **Flash-Attention-2**, adding a **chat template** in training and inference, and **evaluating on more tasks**. However, the code in this folder is under development and subject to change in the future.
# Activation-Beacon
[Activation Beacon](https://arxiv.org/abs/2401.03462) compresses the original KV into fewer yet more compact states (a.k.a. beacons) and hence enables the LLM to perceive longer context given its fixed context window. It is known for the following features:
- **Effective**
- there is little information loss at compression ratios of 2, 4, and 8;
- **Efficient**
- it drastically reduces the GPU consumption of KV cache;
- **Compatible**
- it can work together with position extrapolation (e.g. YaRN) to further extend the context length; it can also work with grouped query attention to further reduce the KV cache size;
- **Low-Cost**
- it is lightweight and can be efficiently trained with roughly 1B tokens.
This folder contains the newer code for Activation Beacon. It supports more LLMs, including Mistral, Llama-3, and Qwen-2. It also supports more features, including **Deepspeed Zero3 training**, **Flash-Attention-2**, adding a **chat template** in training and inference, and **evaluating on more tasks**. However, the code in this folder is under development and subject to change in the future.
## Environment
```bash
conda create -n beacon python=3.10.14
conda activate beacon
# You may need to adjust the cuda version
conda install pytorch pytorch-cuda=12.1 -c pytorch -c nvidia
pip install transformers==4.39.3 deepspeed accelerate datasets peft pandas seaborn rouge fuzzywuzzy jieba python-Levenshtein
pip install flash-attn --no-build-isolation
```
## Usage
```python
import json
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
model_id = "namespace-Pt/beacon-qwen-2-7b-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
model_id,
trust_remote_code=True,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2"
)
model = model.cuda().eval()
with torch.no_grad():
# short context
messages = [{"role": "user", "content": "Tell me about yourself."}]
inputs = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt", return_dict=True).to("cuda")
outputs = model.generate(**inputs, max_new_tokens=50)
print(f"Input Length: {inputs['input_ids'].shape[1]}")
print(f"Output: {repr(tokenizer.decode(outputs[0], skip_special_tokens=True))}")
# reset memory before new generation task
model.memory.reset()
# long context
with open("data/toy/infbench.json", encoding="utf-8") as f:
example = json.load(f)
messages = [{"role": "user", "content": example["context"]}]
inputs = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt", return_dict=True).to("cuda")
outputs = model.generate(**inputs, do_sample=False, top_p=1, temperature=1, max_new_tokens=20)[:, inputs["input_ids"].shape[1]:]
print("*"*20)
print(f"Input Length: {inputs['input_ids'].shape[1]}")
print(f"Answers: {example['answer']}")
print(f"Prediction: {tokenizer.decode(outputs[0], skip_special_tokens=True)}")
```
**NOTE**: It's okay to see warnings like `This is a friendly reminder - the current text generation call will exceed the model's predefined maximum length (32768). Depending on the model, you may observe exceptions, performance degradation, or nothing at all.` Just ignore it.
## Data
You should download the data for fine-tuning & evaluation, then untar the file anywhere you prefer, e.g. `/data`:
```bash
# feel free to change /data to your preferred location
wget https://huggingface.co/datasets/namespace-Pt/projects/resolve/main/long-llm.tar.gz?download=true -O /data/long-llm.tar.gz
cd /data
tar -xzvf long-llm.tar.gz
```
**IMPORTANT NOTE**
For any path specified for `train_data` and `eval_data`: if it is prefixed with `long-llm:`, it will be resolved relative to [`data_root`](./src/args.py).
- e.g. `long-llm:lm/pg19.json` becomes `${data_root}/lm/pg19.json`
- you can modify the default value of [`data_root`](./src/args.py), so that you don't need to type it for each command.
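A minimal sketch of how such a prefixed path might be resolved (illustrative only; the actual logic lives in [`src/args.py`](./src/args.py), and the default `data_root` value below is a placeholder):
```python
import os

def resolve_path(path: str, data_root: str = "/data/long-llm") -> str:
    # paths prefixed with "long-llm:" are interpreted relative to data_root
    prefix = "long-llm:"
    if path.startswith(prefix):
        return os.path.join(data_root, path[len(prefix):])
    return path

print(resolve_path("long-llm:lm/pg19.json"))  # -> /data/long-llm/lm/pg19.json
```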
## Training
See [training section](./docs/training.md).
## Evaluation
See [evaluation section](./docs/evaluation.md).
## Citation
If you find this repository useful, please give us a star ⭐.
To cite our work:
```
@misc{zhang2024soaring,
title={Soaring from 4K to 400K: Extending LLM's Context with Activation Beacon},
author={Peitian Zhang and Zheng Liu and Shitao Xiao and Ninglu Shao and Qiwei Ye and Zhicheng Dou},
year={2024},
eprint={2401.03462},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
```
{
"mixture": {
"commoncrawl": 10,
"c4": 10,
"github": 25,
"book": 10,
"arxiv": 10,
"wiki": 10,
"stackexchange": 25
},
"num_tokens_avg": {
"commoncrawl": 1207,
"c4": 378,
"wiki": 393,
"stackexchange": 309,
"github": 436,
"book": 89373,
"arxiv": 7375
}
}
{
"mixture": {
"commoncrawl": 14.2,
"c4": 14.2,
"github": 14.2,
"book": 14.2,
"arxiv": 14.2,
"wiki": 14.2,
"stackexchange": 14.2
},
"num_tokens_avg": {
"commoncrawl": 1207,
"c4": 378,
"wiki": 393,
"stackexchange": 309,
"github": 436,
"book": 89373,
"arxiv": 7375
}
}
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_backward_prefetch: BACKWARD_PRE
fsdp_cpu_ram_efficient_loading: true
fsdp_forward_prefetch: false
fsdp_offload_params: false
fsdp_sharding_strategy: FULL_SHARD
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sync_module_states: true
fsdp_use_orig_params: true
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false