"examples/research_projects/vscode:/vscode.git/clone" did not exist on "7f0eb35af379def05c111806e12f3ef50a94a54b"
Unverified Commit 048d41a1 authored by Wang Xingjin, committed by GitHub

add vllm get_ppl (#1003)



* add vllm get_ppl

* add vllm get_ppl

* format

---------
Co-authored-by: xingjin.wang <xingjin.wang@mihoyo.com>
Co-authored-by: Leymore <zfz-960727@163.com>
parent 3a232db4
from typing import Dict, List, Optional
import numpy as np
from opencompass.models.base import BaseModel
from opencompass.utils import get_logger
@@ -103,6 +105,29 @@ class VLLM(BaseModel):
        return output_strs

    def get_ppl(self,
                inputs: List[str],
                mask_length: Optional[List[int]] = None) -> List[float]:
        batch_size = len(inputs)
        # prompt_logprobs=0 asks vLLM to return the logprob of each prompt
        # token; SamplingParams comes from vllm (imported elsewhere in this file).
        sampling_kwargs = SamplingParams(prompt_logprobs=0,
                                         **self.generation_kwargs)
        # forward
        outputs = self.model.generate(inputs, sampling_kwargs)
        # compute ppl: mean negative log-likelihood over the prompt tokens
        ce_loss = []
        for i in range(batch_size):
            # The first prompt token has no logprob (no preceding context),
            # so both lists are sliced from index 1.
            prompt_logprobs = outputs[i].prompt_logprobs[1:]
            prompt_token_ids = outputs[i].prompt_token_ids[1:]
            # Select the logprob entry of the actual token at each position.
            prompt_logprobs_list = [
                prompt_logprobs[j][prompt_token_ids[j]]
                for j in range(len(prompt_logprobs))
            ]
            prompt_logprobs_list = [item.logprob for item in prompt_logprobs_list]
            prompt_logprobs_list = np.array(prompt_logprobs_list)
            loss = -prompt_logprobs_list.sum(axis=-1) / len(prompt_token_ids)
            ce_loss.append(loss)
        return np.array(ce_loss)

    def prompts_preproccess(self, inputs: List[str]):
        if self.use_fastchat_template:
            try:
......
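For context, a minimal usage sketch of the new method (the model path, constructor arguments, and prompts below are illustrative placeholders, not taken from this commit). get_ppl returns one value per input, the mean negative log-likelihood of the prompt tokens, so perplexity is obtained by exponentiating:

    import numpy as np
    from opencompass.models import VLLM

    # Illustrative configuration; the path and generation_kwargs are placeholders.
    llm = VLLM(path='meta-llama/Llama-2-7b-hf',
               max_seq_len=2048,
               generation_kwargs=dict())

    prompts = [
        'The capital of France is Paris.',
        'The capital of France is London.',
    ]
    ce_loss = llm.get_ppl(prompts)   # per-sample mean negative log-likelihood
    ppl = np.exp(ce_loss)            # perplexity = exp(cross-entropy)
    print(ppl)                       # the more plausible prompt should score lower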