Unverified Commit 53c65300 authored by Baber Abbasi, committed by GitHub

[SGLANG] Add the SGLANG generate API (#2997)

* add `sglang-generate`

* nit

* nit

* nit

* pacify pre-commit
parent 0daf28fd
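
For orientation (not part of the commit): a hedged sketch of how the newly registered `sglang-generate` backend might be driven from Python. The server URL, port, tokenizer name, and task below are placeholder assumptions, not values taken from this change.

```python
# Hypothetical usage sketch; assumes an SGLang server is already running
# locally and exposing its /generate endpoint at the URL below.
import lm_eval

results = lm_eval.simple_evaluate(
    model="sglang-generate",
    model_args=(
        "base_url=http://127.0.0.1:30000/generate,"
        "tokenizer=meta-llama/Llama-3.1-8B-Instruct"  # placeholder tokenizer
    ),
    tasks=["gsm8k"],
)
```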
...
@@ -401,7 +401,7 @@ class Task(abc.ABC):
         fewshot_as_multiturn: bool = False,
         chat_template: Optional[Callable] = None,
         tokenizer_name: str = "",
-        question_suffix: str = ""
+        question_suffix: str = "",
     ) -> None:
         """Build a set of Instances for a task, and store them in task.instances"""
...
@@ -1077,13 +1077,23 @@ class ConfigurableTask(Task):
         if not fewshot_as_multiturn:
             # if no messages or last message is system, append as new user entry
             if len(labeled_examples) == 0 or labeled_examples[-1]["role"] == "system":
-                labeled_examples.append({"role": "user", "content": question + question_suffix} if question_suffix else {"role": "user", "content": question} )
+                labeled_examples.append(
+                    {"role": "user", "content": question + question_suffix}
+                    if question_suffix
+                    else {"role": "user", "content": question}
+                )
             # if last message is user, append to it to avoid two user messages in a row
             else:
-                labeled_examples[-1]["content"] += question + question_suffix if question_suffix else question
+                labeled_examples[-1]["content"] += (
+                    question + question_suffix if question_suffix else question
+                )
         else:
             # if fewshot_as_multiturn is True, append as next user entry (last is always assistant)
-            labeled_examples.append({"role": "user", "content": question + question_suffix} if question_suffix else {"role": "user", "content": question} )
+            labeled_examples.append(
+                {"role": "user", "content": question + question_suffix}
+                if question_suffix
+                else {"role": "user", "content": question}
+            )
         if gen_prefix:
             labeled_examples.append({"role": "assistant", "content": gen_prefix})
...
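
A minimal illustration (not harness code) of the `question_suffix` handling reformatted above: when a suffix is supplied it is appended to the question before the final user turn is built.

```python
# Standalone sketch of the branch that appends a new user entry.
question = "What is the capital of France?"
question_suffix = "\nAnswer:"

labeled_examples = [{"role": "system", "content": "You are a helpful assistant."}]
# Last message is "system", so the question becomes a new user entry,
# with the suffix appended when one is provided.
labeled_examples.append(
    {"role": "user", "content": question + question_suffix}
    if question_suffix
    else {"role": "user", "content": question}
)
assert labeled_examples[-1]["content"] == "What is the capital of France?\nAnswer:"
```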
...
@@ -16,6 +16,7 @@ from . import (
     optimum_ipex,
     optimum_lm,
     sglang_causallms,
+    sglang_generate_API,
     textsynth,
     vllm_causallms,
     vllm_vlms,
...
from typing import Dict, List, Optional, Tuple, Union

from lm_eval.api.registry import register_model
from lm_eval.models.openai_completions import LocalCompletionsAPI
from lm_eval.models.utils import handle_stop_sequences


@register_model("sglang-generate")
class SGLANGGENERATEAPI(LocalCompletionsAPI):
    def __init__(
        self,
        base_url=None,
        tokenizer_backend="huggingface",
        **kwargs,
    ):
        super().__init__(
            base_url=base_url, tokenizer_backend=tokenizer_backend, **kwargs
        )

    def _create_payload(
        self,
        messages: Union[List[List[int]], List[dict], List[str], str],
        generate=False,
        gen_kwargs: Optional[dict] = None,
        seed: int = 1234,
        eos=None,
        **kwargs,
    ) -> dict:
        # SGLang's /generate endpoint accepts either raw text or token ids.
        is_string = isinstance(messages, str) or isinstance(messages[0], str)
        if generate:
            gen_kwargs.pop("do_sample", False)
            if "max_tokens" in gen_kwargs:
                max_tokens = gen_kwargs.pop("max_tokens")
            else:
                max_tokens = gen_kwargs.pop("max_gen_toks", self._max_gen_toks)
            temperature = gen_kwargs.pop("temperature", 0)
            stop = handle_stop_sequences(gen_kwargs.pop("until", None), eos)
            request = {
                "sampling_params": {
                    "max_new_tokens": max_tokens,
                    "temperature": temperature,
                    "stop": stop,
                    **gen_kwargs,
                },
            }
            if is_string:
                request["text"] = messages
            else:
                request["input_ids"] = messages
            return request
        else:
            assert not is_string, "Logprobs are only supported for tokenized inputs"
            # Request a single greedy token and ask the server to return
            # per-token logprobs over the entire prompt.
            request = {
                "input_ids": messages,
                "sampling_params": {"max_new_tokens": 1, "temperature": 0},
                "logprob_start_len": 0,
                "top_logprobs_num": 1,
                "return_logprob": True,
            }
            return request

    @staticmethod
    def parse_logprobs(
        outputs: Union[Dict, List[Dict]],
        tokens: Optional[List[List[int]]] = None,
        ctxlens: Optional[List[int]] = None,
        **kwargs,
    ) -> List[Tuple[float, bool]]:
        res = []
        if not isinstance(outputs, list):
            outputs = [outputs]
        for choice, ctxlen in zip(outputs, ctxlens):
            choice = choice["meta_info"]
            assert ctxlen > 0, "Context length must be greater than 0"
            # Sum the logprobs of the continuation tokens (positions >= ctxlen).
            logprobs = sum(x[0] for x in choice["input_token_logprobs"][ctxlen:])
            # The continuation is greedy iff every continuation token equals
            # the model's top-1 token at that position.
            is_greedy = all(
                x[1] == y[0][1]
                for x, y in zip(
                    choice["input_token_logprobs"][ctxlen:],
                    choice["input_top_logprobs"][ctxlen:],
                )
            )
            res.append((logprobs, is_greedy))
        return res

    @staticmethod
    def parse_generations(outputs: Union[Dict, List[Dict]], **kwargs) -> List[str]:
        res = []
        if not isinstance(outputs, list):
            outputs = [outputs]
        for out in outputs:
            res.append(out["text"])
        return res

    @property
    def api_key(self):
        # A local SGLang server does not require an API key.
        return ""
...
@@ -487,7 +487,10 @@ class VLLM(TemplateLM):
         inputs = []
         ctxlens = []
         for cache_key, context_enc, continuation_enc in chunk:
-            if full_length := len(context_enc + continuation_enc) >= self.max_length:
+            if (
+                full_length := len(context_enc + continuation_enc)
+            ) >= self.max_length:
                 eval_logger.warning(
                     f"Context length {full_length} exceeds max length ({self.max_length}). Truncating context."
                 )
...
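
A quick aside on why the grouping in the rewritten condition matters: `:=` binds more loosely than `>=`, so without parentheses around the `len(...)` call the walrus captures the comparison's boolean instead of the length.

```python
seq = list(range(10))
max_length = 5

# Ungrouped: the walrus captures the boolean result of the comparison.
if (flag := len(seq) >= max_length):
    assert flag is True

# Grouped: the walrus captures the length itself, which is what the
# truncation warning needs to report.
if (full_length := len(seq)) >= max_length:
    assert full_length == 10
```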
...
@@ -14,13 +14,13 @@ This is the processed version of Google's C4 dataset.
 ```text
 @misc{raffel2023exploringlimitstransferlearning,
       title={Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer},
       author={Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu},
       year={2023},
       eprint={1910.10683},
       archivePrefix={arXiv},
       primaryClass={cs.LG},
       url={https://arxiv.org/abs/1910.10683},
 }
 ```
...
...
@@ -21,4 +21,4 @@ dataset_kwargs:
     validation: en/c4-validation.00000-of-00008.json.gz
 # following the choice of https://arxiv.org/abs/2410.07461
 trust_remote_code: true
-verification_mode: "no_checks"
\ No newline at end of file
+verification_mode: "no_checks"
...
from functools import partial

choices = ["A", "B", "C", "D", "E", "F", "G", "H", "I", "J"]


def format_cot_example(example, including_answer=True):
    prompt = "Question:\n"
    question = example["question"]
...
@@ -21,15 +23,18 @@ def format_cot_example(example, including_answer=True):
        prompt += cot_content + "\n\n"
    else:
        prompt += "Answer: Let's think step by step."
    return prompt


doc_to_text = partial(format_cot_example, including_answer=False)
fewshot_to_text = partial(format_cot_example, including_answer=True)


def process_docs(dataset, subject):
    return dataset.filter(lambda x: x["category"] == subject)


process_biology = partial(process_docs, subject="biology")
process_business = partial(process_docs, subject="business")
process_chemistry = partial(process_docs, subject="chemistry")
...
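
For reference, a hedged sketch of how the per-subject partials above are meant to behave; the records are invented and carry only the `category` field that the filter reads.

```python
# Assumes the process_* partials defined above are in scope.
from datasets import Dataset

ds = Dataset.from_list(
    [
        {"category": "biology", "question": "..."},
        {"category": "chemistry", "question": "..."},
    ]
)
assert process_biology(ds).num_rows == 1
assert process_chemistry(ds).num_rows == 1
```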