"densecrf/include/Eigen/src/Eigen2Support/VectorBlock.h" did not exist on "3ae01ffbeb09a6cc7c7f71392e4af1ba0326a963"
Unverified Commit 44398478 authored by Slim Frikha, committed by GitHub

Ignore seed when splitting batch in chunks with groupby (#3047)



* feat(vllm_causallms): make collator ignore seed when splitting batch into chunks

* fix(collator): revert PR changes

* fix(vllm-causallm): update collator call with groupby None

* feat(sglang-causallms): make generation accept a list of sampling params

---------
Co-authored-by: Baber <baber@hey.com>
parent 4f1e9f7c
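
In short, the change replaces the old "one gen_kwargs per batch" assumption with one sampling-params entry per request. A minimal sketch of that idea (the helper name and the default value are illustrative, not code from this PR):

```python
# Illustrative only: build one sampling-params dict per request instead of assuming the
# whole batch shares a single gen_kwargs. "build_per_request_params" and the 256 default
# are assumptions for this sketch, not names from the harness.
from typing import Any, Dict, List


def build_per_request_params(
    all_gen_kwargs: List[Dict[str, Any]], default_max_gen_toks: int = 256
) -> List[Dict[str, Any]]:
    sampling_params = []
    for gen_kwargs in all_gen_kwargs:
        kwargs = dict(gen_kwargs)  # copy so the pops below don't mutate the caller's dict
        max_gen_toks = kwargs.pop("max_gen_toks", default_max_gen_toks)
        stop = kwargs.pop("until", None)
        sampling_params.append(kwargs | {"max_tokens": max_gen_toks, "stop": stop})
    return sampling_params


# Requests with different settings can now share one batch:
print(build_per_request_params([{"temperature": 0.0}, {"temperature": 0.8, "max_gen_toks": 64}]))
```

Because every request carries its own parameters, the collator no longer needs to group requests by gen_kwargs before chunking, which, per the commit title, is what made chunking sensitive to the seed.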
@@ -216,7 +216,7 @@ class SGLangLM(TemplateLM):
         # we group requests by their generation_kwargs,
         # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
         # in the same batch.
-        re_ords = Collator(requests, _collate_gen, group_by="gen_kwargs")
+        re_ords = Collator(requests, _collate_gen, group_by=None)
         chunks = re_ords.get_batched(
             n=int(self.batch_size) if self.batch_size != "auto" else 0, batch_fn=None
         )
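
A toy illustration of the difference (this is not the harness's Collator; `itertools.groupby` stands in for the grouping step):

```python
# Grouping by the full gen_kwargs means any per-request field, e.g. a seed, changes how
# requests are bucketed before chunking; with group_by=None the chunks are plain
# consecutive slices of the incoming order.
from itertools import groupby


def chunked(xs, n):
    return [xs[i : i + n] for i in range(0, len(xs), n)]


requests = [
    ("req0", {"temperature": 0.0, "seed": 1}),
    ("req1", {"temperature": 0.0, "seed": 2}),
    ("req2", {"temperature": 0.0, "seed": 1}),
    ("req3", {"temperature": 0.0, "seed": 2}),
]


def key(r):
    return sorted(r[1].items())


# group_by="gen_kwargs"-style: one chunk stream per distinct kwargs -> the seed splits them
grouped_chunks = [
    chunk
    for _, group in groupby(sorted(requests, key=key), key=key)
    for chunk in chunked(list(group), 2)
]
# roughly [[req0, req2], [req1, req3]] -- batch composition depends on the seed values

# group_by=None-style: chunk the incoming order directly
ungrouped_chunks = chunked(requests, 2)
# roughly [[req0, req1], [req2, req3]] -- independent of gen_kwargs contents
```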
@@ -232,36 +232,41 @@ class SGLangLM(TemplateLM):
             context_and_encoding, all_gen_kwargs = zip(*chunk)
             context, context_encoding = zip(*context_and_encoding)
-            # we assume all gen kwargs in the batch are the same
-            # this is safe to assume because the `grouper` object ensures it.
-            gen_kwargs = all_gen_kwargs[0]
-            # unpack our keyword arguments.
-            if isinstance(gen_kwargs, dict):
-                kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
-                # add EOS token to stop sequences
-                until = handle_stop_sequences(kwargs.pop("until", None), eos=eos)
-            else:
-                raise ValueError(
-                    f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
-                )
-            if "max_gen_toks" in kwargs.keys():
-                max_gen_toks = kwargs.pop("max_gen_toks")
-            else:
-                max_gen_toks = self.max_gen_toks
-            # set the max length in tokens of inputs ("context_enc")
-            # max len for inputs = max length, minus room to generate the max new tokens
-            max_ctx_len = self.max_length - max_gen_toks
-            context_encoding = [x[-max_ctx_len:] for x in context_encoding]
+            context_encoding_truncated = []
+            sampling_params = []
+            for x, gen_kwargs in zip(context_encoding, all_gen_kwargs):
+                # unpack our keyword arguments.
+                if isinstance(gen_kwargs, dict):
+                    kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
+                    # add EOS token to stop sequences
+                    until = handle_stop_sequences(kwargs.pop("until", None), eos=eos)
+                else:
+                    raise ValueError(
+                        f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
+                    )
+                if "max_gen_toks" in kwargs.keys():
+                    max_gen_toks = kwargs.pop("max_gen_toks")
+                else:
+                    max_gen_toks = self.max_gen_toks
+                # set the max length in tokens of inputs ("context_enc")
+                # max len for inputs = max length, minus room to generate the max new tokens
+                max_ctx_len = self.max_length - max_gen_toks
+                if len(x) > max_ctx_len:
+                    context_encoding_truncated.append(x[-max_ctx_len:])
+                else:
+                    context_encoding_truncated.append(x)
+                # create sampling params
+                kwargs = self.modify_gen_kwargs(kwargs)
+                sampling_params.append(
+                    kwargs | {"max_tokens": max_gen_toks, "stop": until}
+                )
             # perform batched generation
             # cont is a list of dic. See here https://github.com/sgl-project/sglang/blob/0a6f18f068e4095fc228e798454e8496c9749214/python/sglang/srt/entrypoints/engine.py#L111 .
             cont = self._model_generate(
-                requests=context_encoding,
+                requests=context_encoding_truncated,
                 generate=True,
-                max_tokens=max_gen_toks,
-                stop=until,
-                **kwargs,
+                sampling_params=sampling_params,
             )
             # cache generations
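
The per-request truncation above keeps only the last `max_length - max_gen_toks` tokens of each encoded context, so prompt plus generation always fits the model window. A standalone sketch (hypothetical helper name, not harness code):

```python
# Keep only the final (max_length - max_gen_toks) tokens of each encoded context.
from typing import List


def truncate_left(
    context_encoding: List[List[int]], max_length: int, max_gen_toks: int
) -> List[List[int]]:
    max_ctx_len = max_length - max_gen_toks
    return [x[-max_ctx_len:] if len(x) > max_ctx_len else x for x in context_encoding]


# The oldest tokens are dropped; the most recent max_ctx_len tokens are kept.
assert truncate_left([[1, 2, 3, 4, 5]], max_length=6, max_gen_toks=3) == [[3, 4, 5]]
```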
@@ -284,28 +289,22 @@ class SGLangLM(TemplateLM):
         self,
         requests: List[List[int]] = None,
         generate: bool = False,
-        max_tokens: int = None,
-        stop: Optional[List[str]] = None,
+        sampling_params: Union[List[Dict], Dict, None] = None,
         return_logprob: bool = False,
         top_logprobs_num: int = 1,
         logprob_start_len: int = -1,
-        **kwargs,
     ):
         # check sglang sampling parameters: https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/sampling/sampling_params.py#L21 and https://docs.sglang.ai/references/sampling_params.html.
-        if generate:
-            kwargs = self.modify_gen_kwargs(kwargs)
-            sampling_params = {
-                "max_new_tokens": max_tokens,
-                "stop": stop,
-            }
-            sampling_params.update(kwargs)
-        else:
-            sampling_params = {
-                "temperature": 0,
-                "max_new_tokens": 1,
-            }
-            sampling_params.update(kwargs)
+        if not generate:
+            sampling_params = sampling_params if sampling_params else {}
+            sampling_params.update(
+                {
+                    "temperature": 0,
+                    "max_new_tokens": 1,
+                }
+            )
+        if not isinstance(sampling_params, List):
+            sampling_params = [sampling_params] * len(requests)
         # Refer to: https://docs.sglang.ai/backend/offline_engine_api.html
         outputs = self.model.generate(
             input_ids=requests,
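
The tail of the new `_model_generate` normalizes whatever the caller passed into one entry per request. A standalone sketch of that handling (hypothetical helper, not the harness API):

```python
# Scoring calls fall back to a greedy single-token config; a single dict (or None) is
# broadcast into one entry per request so downstream code can always index per request.
from typing import Dict, List, Union


def resolve_params(
    sampling_params: Union[List[Dict], Dict, None],
    num_requests: int,
    generate: bool,
) -> List[Dict]:
    if not generate:
        sampling_params = sampling_params if sampling_params else {}
        sampling_params.update({"temperature": 0, "max_new_tokens": 1})
    if not isinstance(sampling_params, list):
        sampling_params = [sampling_params] * num_requests
    return sampling_params


print(resolve_params(None, 2, generate=False))
# [{'temperature': 0, 'max_new_tokens': 1}, {'temperature': 0, 'max_new_tokens': 1}]
```

Note that broadcasting with `[sampling_params] * len(requests)` shares one dict object across all entries, which is fine as long as downstream code only reads the entries and does not mutate them per request.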
@@ -50,7 +50,7 @@ eval_logger = logging.getLogger(__name__)
 def _vllm_mp_worker(
     model_args: dict,
-    sampling_params: "SamplingParams",
+    sampling_params: "list[SamplingParams]",
     requests: list[list[int]],
     lora_request: "LoRARequest",
     result_queue: "Queue",
@@ -364,17 +364,14 @@ class VLLM(TemplateLM):
         self,
         requests: List[List[int]] = None,
         generate: bool = False,
-        max_tokens: int = None,
-        stop: Optional[List[str]] = None,
-        **kwargs,
+        sampling_params: Union[List[SamplingParams], SamplingParams, None] = None,
     ):
-        if generate:
-            kwargs = self.modify_gen_kwargs(kwargs)
-            sampling_params = SamplingParams(max_tokens=max_tokens, stop=stop, **kwargs)
-        else:
+        if not generate or sampling_params is None:
             sampling_params = SamplingParams(
                 temperature=0, prompt_logprobs=1, max_tokens=1, detokenize=False
             )
+        if not isinstance(sampling_params, List):
+            sampling_params = [sampling_params] * len(requests)
         if self.data_parallel_size > 1 and not self.V1:
             # vLLM hangs if resources are set in ray.remote
             # also seems to only work with decorator and not with ray.remote() fn
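
The scoring default above, spelled out (assumes vLLM is installed; the comments describe standard `SamplingParams` fields):

```python
# Loglikelihood scoring needs only the prompt logprobs, so every request gets the same
# single-step greedy configuration, broadcast to match the number of requests.
from vllm import SamplingParams

num_requests = 4
default = SamplingParams(
    temperature=0,      # greedy decoding
    prompt_logprobs=1,  # return logprobs for the prompt tokens (what loglikelihood uses)
    max_tokens=1,       # only one decoding step is needed
    detokenize=False,   # skip text decoding; only the logprobs are consumed
)
sampling_params = [default] * num_requests  # one (shared) entry per request
```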
@@ -382,7 +379,7 @@
             @ray.remote
             def run_inference_one_model(
                 model_args: dict,
-                sampling_params: SamplingParams,
+                sampling_params: List[SamplingParams],
                 requests: List[List[int]],
                 lora_request: LoRARequest,
             ):
@@ -396,9 +393,12 @@ class VLLM(TemplateLM):
             # dispatch requests to all self.data_parallel_size workers, in interleaved fashion
             # interleaved important to balance context lengths across workers
             requests = [list(x) for x in distribute(self.data_parallel_size, requests)]
+            sampling_params = [
+                list(sp) for sp in distribute(self.data_parallel_size, sampling_params)
+            ]
             inputs = (
-                (self.model_args, sampling_params, req, self.lora_request)
-                for req in requests
+                (self.model_args, sp, req, self.lora_request)
+                for req, sp in zip(requests, sampling_params)
             )
             object_refs = [run_inference_one_model.remote(*x) for x in inputs]
             results = ray.get(object_refs)
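
A small demonstration of the dispatch scheme (`distribute` here is `more_itertools.distribute`, which this module already uses): both lists are split round-robin across workers and re-paired with `zip`, so every worker receives aligned (requests, sampling_params) shards.

```python
# Requests and their matching sampling params are interleaved across workers the same
# way, so req[i] still pairs with sp[i] inside each shard.
from more_itertools import distribute

data_parallel_size = 2
requests = [[1], [2], [3], [4], [5]]
sampling_params = ["p1", "p2", "p3", "p4", "p5"]  # stand-ins for SamplingParams objects

req_shards = [list(x) for x in distribute(data_parallel_size, requests)]
sp_shards = [list(sp) for sp in distribute(data_parallel_size, sampling_params)]

for rank, (req, sp) in enumerate(zip(req_shards, sp_shards)):
    print(rank, req, sp)
# 0 [[1], [3], [5]] ['p1', 'p3', 'p5']
# 1 [[2], [4]] ['p2', 'p4']
```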
@@ -413,16 +413,18 @@ class VLLM(TemplateLM):
             dp_master_port = os.environ.get("VLLM_DP_MASTER_PORT") or get_open_port()
             requests = (list(x) for x in distribute(self.data_parallel_size, requests))
+            sampling_params = (
+                list(sp) for sp in distribute(self.data_parallel_size, sampling_params)
+            )
             procs, resq = [], Queue()
             # We use Process as it is non-daemonic
             try:
-                for rank, req in enumerate(requests):
+                for rank, (sp, req) in enumerate(zip(requests, sampling_params)):
                     proc = Process(
                         target=_vllm_mp_worker,
                         args=(
                             self.model_args.copy(),
-                            sampling_params,
+                            sp,
                             req,
                             self.lora_request,
                             resq,
@@ -576,10 +578,11 @@ class VLLM(TemplateLM):
             # - any OOMs will happen right away rather than near the end
             return -len(_requests[0][1]), _requests[0][0]

-        # we group requests by their generation_kwargs,
-        # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
-        # in the same batch.
-        re_ords = Collator(requests, _collate_gen, group_by="gen_kwargs")
+        re_ords = Collator(
+            requests,
+            _collate_gen,
+            group_by=None,
+        )
         chunks = re_ords.get_batched(
             n=int(self.batch_size) if self.batch_size != "auto" else 0, batch_fn=None
         )
@@ -594,41 +597,44 @@ class VLLM(TemplateLM):
         for chunk in chunks:
             context_and_encoding, all_gen_kwargs = zip(*chunk)
             context, context_encoding = zip(*context_and_encoding)
-            # we assume all gen kwargs in the batch are the same
-            # this is safe to assume because the `grouper` object ensures it.
-            gen_kwargs = all_gen_kwargs[0]
-            # unpack our keyword arguments.
-            if isinstance(gen_kwargs, dict):
-                kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
-                # add EOS token to stop sequences
-                until = handle_stop_sequences(kwargs.pop("until", None), eos=eos)
-            else:
-                raise ValueError(
-                    f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
-                )
-            if "max_gen_toks" in kwargs.keys():
-                max_gen_toks = kwargs.pop("max_gen_toks")
-            else:
-                max_gen_toks = self.max_gen_toks
-            # set the max length in tokens of inputs ("context_enc")
-            # max len for inputs = max length, minus room to generate the max new tokens
-            max_ctx_len = self.max_length - max_gen_toks
-            all_lengths = [len(x) for x in context_encoding]
-            for length in all_lengths:
-                if length > max_ctx_len:
+            context_encoding_truncated = []
+            sampling_params = []
+            for x, gen_kwargs in zip(context_encoding, all_gen_kwargs):
+                # unpack our keyword arguments.
+                if isinstance(gen_kwargs, dict):
+                    kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
+                    # add EOS token to stop sequences
+                    until = handle_stop_sequences(kwargs.pop("until", None), eos=eos)
+                else:
+                    raise ValueError(
+                        f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
+                    )
+                if "max_gen_toks" in kwargs.keys():
+                    max_gen_toks = kwargs.pop("max_gen_toks")
+                else:
+                    max_gen_toks = self.max_gen_toks
+                # set the max length in tokens of inputs ("context_enc")
+                # max len for inputs = max length, minus room to generate the max new tokens
+                max_ctx_len = self.max_length - max_gen_toks
+                if len(x) > max_ctx_len:
                     eval_logger.warning(
-                        f"Context length {length} exceeds max length (context + max gen tokens): {max_ctx_len}. Truncating context."
+                        f"Context length {len(x)} exceeds max length (context + max gen tokens): {max_ctx_len}. Truncating context."
                     )
-            context_encoding = [x[-max_ctx_len:] for x in context_encoding]
+                    context_encoding_truncated.append(x[-max_ctx_len:])
+                else:
+                    context_encoding_truncated.append(x)
+                # create sampling params
+                kwargs = self.modify_gen_kwargs(kwargs)
+                sampling_params.append(
+                    SamplingParams(max_tokens=max_gen_toks, stop=until, **kwargs)
+                )
             # perform batched generation
             cont = self._model_generate(
-                requests=context_encoding,
+                requests=context_encoding_truncated,
                 generate=True,
-                max_tokens=max_gen_toks,
-                stop=until,
-                **kwargs,
+                sampling_params=sampling_params,
             )
             # cache generations
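
The net effect on the vLLM generation path, as a toy example (assumes vLLM is installed; the gen_kwargs values are invented): requests with different generation settings can now share a chunk, each carried by its own `SamplingParams` rather than one shared set of kwargs per batch.

```python
# A greedy request and a sampled request in the same chunk, each with its own params.
from vllm import SamplingParams

all_gen_kwargs = [
    {"temperature": 0.0, "until": ["\n\n"], "max_gen_toks": 32},
    {"temperature": 0.8, "top_p": 0.95, "max_gen_toks": 128},
]

sampling_params = []
for gen_kwargs in all_gen_kwargs:
    kwargs = dict(gen_kwargs)
    max_gen_toks = kwargs.pop("max_gen_toks", 256)
    stop = kwargs.pop("until", None)
    sampling_params.append(SamplingParams(max_tokens=max_gen_toks, stop=stop, **kwargs))

# len(sampling_params) equals the number of requests in the chunk, so the i-th request
# is always generated with the i-th set of parameters.
```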