Unverified commit 7199acc2, authored by Yuan Feng, committed by GitHub

Add support for DataCanvas Alaya LM (#612)

* Support for Alaya

* Remove useless requirements
parent dbacd363
configs/eval_alaya.py

from mmengine.config import read_base

with read_base():
    from .datasets.ceval.ceval_gen import ceval_datasets
    from .datasets.cmmlu.cmmlu_gen import cmmlu_datasets
    from .datasets.agieval.agieval_gen import agieval_datasets
    from .datasets.bbh.bbh_gen import bbh_datasets
    from .datasets.mmlu.mmlu_gen import mmlu_datasets
    from .models.alaya.alaya import models

datasets = [*bbh_datasets, *ceval_datasets, *cmmlu_datasets,
            *agieval_datasets, *mmlu_datasets]
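
With the file above saved as configs/eval_alaya.py, the full benchmark suite would typically be launched through OpenCompass's standard entry point, e.g. python run.py configs/eval_alaya.py (the exact invocation depends on the local setup).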
configs/models/alaya/alaya.py

from opencompass.models import AlayaLM

models = [
    dict(
        type=AlayaLM,
        abbr='alaya-7b-hf',
        path='DataCanvas/Alaya-7B-Base',
        tokenizer_path='DataCanvas/Alaya-7B-Base',
        tokenizer_kwargs=dict(padding_side='left',
                              truncation_side='left',
                              trust_remote_code=True,
                              use_fast=False),
        max_out_len=100,
        max_seq_len=2048,
        batch_size=8,
        model_kwargs=dict(device_map='auto', trust_remote_code=True),
        run_cfg=dict(num_gpus=1, num_procs=1),
    )
]
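
For a quick smoke test outside the full evaluation pipeline, the wrapper can also be instantiated directly. A minimal sketch, assuming the Alaya weights are available locally and a GPU is visible; the prompt is illustrative only:

# Hypothetical smoke test (not part of the commit); assumes local
# Alaya-7B-Base weights and one visible GPU.
from opencompass.models import AlayaLM

model = AlayaLM(path='DataCanvas/Alaya-7B-Base')
print(model.generate(['Briefly introduce yourself.'])[0])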
opencompass/models/__init__.py

from .alaya import AlayaLM  # noqa: F401
from .base import BaseModel, LMTemplateParser  # noqa
from .base_api import APITemplateParser, BaseAPIModel  # noqa
from .claude_api import Claude  # noqa: F401
...
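
Registering the class in this __init__.py is what lets the model config above resolve from opencompass.models import AlayaLM.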
opencompass/models/alaya.py

from typing import Dict, List, Optional, Union

import torch
from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
                          pipeline)

from opencompass.utils.prompt import PromptList

from .base import BaseModel, LMTemplateParser

PromptType = Union[PromptList, str]


class AlayaLM(BaseModel):
    """Model wrapper for the Alaya model.

    Args:
        path (str): The name of or path to the Alaya model; either a local
            path or a Hugging Face model tag.
        max_seq_len (int): The maximum length of the input sequence.
            Defaults to 2048.
        tokenizer_only (bool): If True, only the tokenizer is initialized.
            Defaults to False.
        meta_template (Dict, optional): The model's meta prompt template,
            in case any meta instructions need to be injected or wrapped.

    Note:
        Alaya has some arguments that should stay fixed, such as
        eos_token_id and bad_words_ids.
        The model config should be loaded from a model config file.
        Triton is supported to accelerate inference.
        This class supports both the Alaya Base and Alaya Chat models.
    """
    def __init__(self,
                 path: str,
                 max_seq_len: int = 2048,
                 tokenizer_only: bool = False,
                 meta_template: Optional[Dict] = None,
                 **kwargs):
        self.template_parser = LMTemplateParser(meta_template)
        self.max_seq_len = max_seq_len
        self.tokenizer_only = tokenizer_only
        self.meta_template = meta_template

        self.name = path
        self.eos_token_id = 2
        self.bad_words_ids = 3
        self.gpu_id = '0'

        self.config = AutoConfig.from_pretrained(self.name,
                                                 trust_remote_code=True,
                                                 local_files_only=True)
        self.config.attn_config['attn_impl'] = 'triton'
        self.config.init_device = 'cuda:' + self.gpu_id

        self.model = AutoModelForCausalLM.from_pretrained(
            self.name,
            config=self.config,
            torch_dtype=torch.bfloat16,  # load model weights in bfloat16
            trust_remote_code=True,
        )

        self.tokenizer = AutoTokenizer.from_pretrained(self.name,
                                                       local_files_only=True,
                                                       padding_side='left')

        self.pipe = pipeline('text-generation',
                             model=self.model,
                             tokenizer=self.tokenizer,
                             bad_words_ids=[[self.bad_words_ids]],
                             eos_token_id=self.eos_token_id,
                             pad_token_id=self.eos_token_id,
                             device='cuda:' + self.gpu_id)
    def do_inference(self, instruction, history=None):
        PROMPT_FORMAT = '### Instruction:\t\n{instruction}\n\n'
        OUTPUT_FORMAT = '### Output:\t\n{output} </s>'

        prompt = PROMPT_FORMAT.format(instruction=instruction)

        if history is None:  # avoid a mutable default argument
            history = []
        history2llm = []
        for i, msg in enumerate(history):
            if i % 2 == 0:  # user turn
                msg2llm = PROMPT_FORMAT.format(instruction=msg)
            else:  # alaya turn
                msg2llm = OUTPUT_FORMAT.format(output=msg)
            history2llm.append(msg2llm)

        flag = '### Output:\t\n'
        prompt2LLM = ''.join(history2llm) + prompt

        if len(prompt2LLM) >= 1500:
            prompt2LLM = prompt2LLM[-1500:]

        result = self.pipe(prompt2LLM,
                           max_new_tokens=100,
                           max_length=1900,
                           do_sample=True,
                           use_cache=True,
                           eos_token_id=self.eos_token_id,
                           pad_token_id=self.eos_token_id)

        try:
            output = result[0]['generated_text'][len(prompt2LLM):]
            # str.lstrip strips a character set rather than a prefix, so
            # remove the leading output marker explicitly.
            if output.startswith(flag):
                output = output[len(flag):]
        except Exception:
            output = result[0]['generated_text']
        return output
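
    # Example of the serialized prompt (illustrative values, not from the
    # commit): with history=['Hi', 'Hello!'] and instruction='Bye',
    # prompt2LLM becomes
    #   '### Instruction:\t\nHi\n\n### Output:\t\nHello! </s>'
    #   '### Instruction:\t\nBye\n\n'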
    def generate(
        self,
        inputs: List[str],
        max_out_len: int = 1000,
    ) -> List[str]:
        """Generate results given a list of inputs."""
        outputs = []
        for instruction in inputs:
            output = self.do_inference(instruction=instruction)
            outputs.append(output)
        return outputs

    def get_token_len(self, prompt: str) -> int:
        """Get the length of the tokenized string."""
        return len(self.tokenizer.encode(prompt))
    def get_ppl(self,
                inputs: List[str],
                mask_length: Optional[List[int]] = None) -> List[float]:
        """Adapted from huggingface.py; returns the mean per-token
        cross-entropy for each input."""
        assert mask_length is None, 'mask_length is not supported'
        bsz = len(inputs)
        # tokenize
        prompt_tokens = [self.tokenizer.encode(x) for x in inputs]
        max_prompt_size = max(len(t) for t in prompt_tokens)
        # HF models expose no .params; cap the length by max_seq_len instead.
        total_len = min(self.max_seq_len, max_prompt_size)
        tokens = torch.zeros((bsz, total_len), dtype=torch.long).cuda()
        for k, t in enumerate(prompt_tokens):
            num_token = min(total_len, len(t))
            tokens[k, :num_token] = torch.tensor(t[-num_token:]).long()
        # forward: take the logits from the causal LM
        outputs = self.model(tokens).logits
        # compute ppl
        shift_logits = outputs[..., :-1, :].contiguous().float()
        shift_labels = tokens[..., 1:].contiguous()
        shift_logits = shift_logits.view(-1, shift_logits.size(-1))
        shift_labels = shift_labels.view(-1)
        loss_fct = torch.nn.CrossEntropyLoss(reduction='none', ignore_index=0)
        loss = loss_fct(shift_logits, shift_labels).view(bsz, -1)
        lens = (tokens != 0).sum(-1).cpu().numpy()
        ce_loss = loss.sum(-1).cpu().detach().numpy() / lens
        return ce_loss
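
The shift-and-mask pattern in get_ppl is the standard next-token scoring recipe: the logits at position i are scored against the token at position i + 1, and padding (id 0) is masked out of the loss. A tiny self-contained illustration with synthetic tensors (all shapes and values hypothetical):

import torch

# Synthetic example: batch of 1, sequence of 4 token ids, vocab of 5.
tokens = torch.tensor([[2, 4, 1, 0]])        # trailing 0 acts as padding
logits = torch.randn(1, 4, 5)                # stand-in for model output

shift_logits = logits[..., :-1, :].reshape(-1, 5)  # predictions for pos 1..3
shift_labels = tokens[..., 1:].reshape(-1)         # gold next tokens
loss = torch.nn.functional.cross_entropy(
    shift_logits, shift_labels, reduction='none', ignore_index=0)
per_sample = loss.view(1, -1).sum(-1) / (tokens != 0).sum(-1)
print(per_sample)  # mean cross-entropy per non-padding token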
requirements diff:

@@ -5,6 +5,7 @@ cn2an
 colossalai
 cpm_kernels
 datasets>=2.12.0
+einops==0.5.0
 evaluate>=0.3.0
 fairscale
 fuzzywuzzy
...