Unverified Commit 44398478 authored by Slim Frikha, committed by GitHub

Ignore seed when splitting batch in chunks with groupby (#3047)



* feat(vllm_causallms): make collator ignore seed when splitting batch into chunks

* fix(collator): revert PR changes

* fix(vllm-causallm): update collator call with groupby None

* feat(sglang-causallms): make generation accept a list of sampling params

---------
Co-authored-by: Baber <baber@hey.com>
parent 4f1e9f7c
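
Overview of the change: instead of grouping requests by their gen_kwargs before batching (so that every chunk shares a single set of sampling settings), both backends now keep the Collator's ordering (group_by=None) and build one sampling-params entry per request inside each chunk. Chunking therefore no longer depends on anything inside gen_kwargs, such as a per-request seed. A minimal sketch of the per-request pattern, using plain dicts and a hypothetical build_sampling_params helper rather than the harness's real types:

from typing import Any, Dict, List, Tuple

# Hypothetical illustration of the per-request pattern introduced by this PR:
# each request carries its own gen_kwargs, and we build one sampling-params
# entry per request instead of assuming the whole batch shares one.
def build_sampling_params(
    chunk: List[Tuple[List[int], Dict[str, Any]]],
    default_max_gen_toks: int = 256,
) -> Tuple[List[List[int]], List[Dict[str, Any]]]:
    encodings, sampling_params = [], []
    for context_encoding, gen_kwargs in chunk:
        kwargs = dict(gen_kwargs)  # don't mutate the caller's dict
        max_gen_toks = kwargs.pop("max_gen_toks", default_max_gen_toks)
        encodings.append(context_encoding)
        sampling_params.append({**kwargs, "max_tokens": max_gen_toks})
    return encodings, sampling_params

# Two requests with different sampling settings can now share a batch.
chunk = [([1, 2, 3], {"temperature": 0.0}), ([4, 5], {"temperature": 0.8, "seed": 1234})]
print(build_sampling_params(chunk))
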
@@ -216,7 +216,7 @@ class SGLangLM(TemplateLM):
         # we group requests by their generation_kwargs,
         # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
         # in the same batch.
-        re_ords = Collator(requests, _collate_gen, group_by="gen_kwargs")
+        re_ords = Collator(requests, _collate_gen, group_by=None)
         chunks = re_ords.get_batched(
             n=int(self.batch_size) if self.batch_size != "auto" else 0, batch_fn=None
         )
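
For context on the PR title: with group_by="gen_kwargs" the Collator batches requests within groups keyed on their generation kwargs (that is what the comment above refers to), so any per-request value inside gen_kwargs, such as a seed, changes how a batch is split into chunks. With group_by=None the chunking depends only on the request ordering. A rough standard-library stand-in for that behaviour (not the real Collator API):

import json
from itertools import groupby

requests = [
    ("prompt a", {"temperature": 0.8, "seed": 1}),
    ("prompt b", {"temperature": 0.8, "seed": 2}),
    ("prompt c", {"temperature": 0.8, "seed": 3}),
]

def key(req):
    # Roughly what grouping by gen_kwargs amounts to: a key built from the dict.
    return json.dumps(req[1], sort_keys=True)

# Grouping on the serialized gen_kwargs: distinct seeds produce distinct groups,
# so each request lands in its own chunk.
by_kwargs = [list(g) for _, g in groupby(sorted(requests, key=key), key=key)]
print([len(g) for g in by_kwargs])  # [1, 1, 1]

# Ignoring gen_kwargs when grouping: all three requests can share one chunk,
# and their differing sampling settings are resolved per request later.
print([len(requests)])  # [3]
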
@@ -232,36 +232,41 @@ class SGLangLM(TemplateLM):
             context_and_encoding, all_gen_kwargs = zip(*chunk)
             context, context_encoding = zip(*context_and_encoding)
-            # we assume all gen kwargs in the batch are the same
-            # this is safe to assume because the `grouper` object ensures it.
-            gen_kwargs = all_gen_kwargs[0]
-            # unpack our keyword arguments.
-            if isinstance(gen_kwargs, dict):
-                kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
-                # add EOS token to stop sequences
-                until = handle_stop_sequences(kwargs.pop("until", None), eos=eos)
-            else:
-                raise ValueError(
-                    f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
-                )
-            if "max_gen_toks" in kwargs.keys():
-                max_gen_toks = kwargs.pop("max_gen_toks")
-            else:
-                max_gen_toks = self.max_gen_toks
-            # set the max length in tokens of inputs ("context_enc")
-            # max len for inputs = max length, minus room to generate the max new tokens
-            max_ctx_len = self.max_length - max_gen_toks
-            context_encoding = [x[-max_ctx_len:] for x in context_encoding]
+            context_encoding_truncated = []
+            sampling_params = []
+            for x, gen_kwargs in zip(context_encoding, all_gen_kwargs):
+                # unpack our keyword arguments.
+                if isinstance(gen_kwargs, dict):
+                    kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
+                    # add EOS token to stop sequences
+                    until = handle_stop_sequences(kwargs.pop("until", None), eos=eos)
+                else:
+                    raise ValueError(
+                        f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
+                    )
+                if "max_gen_toks" in kwargs.keys():
+                    max_gen_toks = kwargs.pop("max_gen_toks")
+                else:
+                    max_gen_toks = self.max_gen_toks
+                # set the max length in tokens of inputs ("context_enc")
+                # max len for inputs = max length, minus room to generate the max new tokens
+                max_ctx_len = self.max_length - max_gen_toks
+                if len(x) > max_ctx_len:
+                    context_encoding_truncated.append(x[-max_ctx_len:])
+                else:
+                    context_encoding_truncated.append(x)
+                # create sampling params
+                kwargs = self.modify_gen_kwargs(kwargs)
+                sampling_params.append(
+                    kwargs | {"max_tokens": max_gen_toks, "stop": until}
+                )
             # perform batched generation
             # cont is a list of dic. See here https://github.com/sgl-project/sglang/blob/0a6f18f068e4095fc228e798454e8496c9749214/python/sglang/srt/entrypoints/engine.py#L111 .
             cont = self._model_generate(
-                requests=context_encoding,
+                requests=context_encoding_truncated,
                 generate=True,
-                max_tokens=max_gen_toks,
-                stop=until,
-                **kwargs,
+                sampling_params=sampling_params,
             )
             # cache generations
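
A side note on the merge above: `kwargs | {...}` is plain dict union (Python 3.9+), and keys from the right-hand dict win, so the explicit max_tokens/stop values override anything of the same name left in a request's gen_kwargs. With made-up values:

kwargs = {"temperature": 0.7, "stop": ["###"]}
merged = kwargs | {"max_tokens": 128, "stop": ["</s>"]}
print(merged)  # {'temperature': 0.7, 'stop': ['</s>'], 'max_tokens': 128}
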
@@ -284,28 +289,22 @@ class SGLangLM(TemplateLM):
         self,
         requests: List[List[int]] = None,
         generate: bool = False,
-        max_tokens: int = None,
-        stop: Optional[List[str]] = None,
+        sampling_params: Union[List[Dict], Dict, None] = None,
         return_logprob: bool = False,
         top_logprobs_num: int = 1,
         logprob_start_len: int = -1,
-        **kwargs,
     ):
         # check sglang sampling parameters: https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/sampling/sampling_params.py#L21 and https://docs.sglang.ai/references/sampling_params.html.
-        if generate:
-            kwargs = self.modify_gen_kwargs(kwargs)
-            sampling_params = {
-                "max_new_tokens": max_tokens,
-                "stop": stop,
-            }
-            sampling_params.update(kwargs)
-        else:
-            sampling_params = {
-                "temperature": 0,
-                "max_new_tokens": 1,
-            }
-            sampling_params.update(kwargs)
+        if not generate:
+            sampling_params = sampling_params if sampling_params else {}
+            sampling_params.update(
+                {
+                    "temperature": 0,
+                    "max_new_tokens": 1,
+                }
+            )
+        if not isinstance(sampling_params, List):
+            sampling_params = [sampling_params] * len(requests)
         # Refer to: https://docs.sglang.ai/backend/offline_engine_api.html
         outputs = self.model.generate(
             input_ids=requests,
...
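
With the new signature, callers may pass either a single sampling-params dict or one dict per request; the non-generate path falls back to a greedy one-token configuration, and a single dict is broadcast by list replication (every entry references the same dict, which is fine for read-only use). The normalization logic in isolation, as plain Python rather than a call into the SGLang engine:

from typing import Dict, List, Union

def normalize(sampling_params: Union[List[Dict], Dict, None], n_requests: int) -> List[Dict]:
    # Mirrors the broadcast step: a single dict becomes one entry per request.
    if sampling_params is None:
        sampling_params = {"temperature": 0, "max_new_tokens": 1}
    if not isinstance(sampling_params, list):
        sampling_params = [sampling_params] * n_requests
    return sampling_params

params = normalize({"temperature": 0, "max_new_tokens": 1}, n_requests=3)
print(len(params), params[0] is params[1])  # 3 True -- same dict object, read-only use
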
@@ -50,7 +50,7 @@ eval_logger = logging.getLogger(__name__)
 def _vllm_mp_worker(
     model_args: dict,
-    sampling_params: "SamplingParams",
+    sampling_params: "list[SamplingParams]",
     requests: list[list[int]],
     lora_request: "LoRARequest",
     result_queue: "Queue",
@@ -364,17 +364,14 @@ class VLLM(TemplateLM):
         self,
         requests: List[List[int]] = None,
         generate: bool = False,
-        max_tokens: int = None,
-        stop: Optional[List[str]] = None,
-        **kwargs,
+        sampling_params: Union[List[SamplingParams], SamplingParams, None] = None,
     ):
-        if generate:
-            kwargs = self.modify_gen_kwargs(kwargs)
-            sampling_params = SamplingParams(max_tokens=max_tokens, stop=stop, **kwargs)
-        else:
+        if not generate or sampling_params is None:
             sampling_params = SamplingParams(
                 temperature=0, prompt_logprobs=1, max_tokens=1, detokenize=False
             )
+        if not isinstance(sampling_params, List):
+            sampling_params = [sampling_params] * len(requests)
         if self.data_parallel_size > 1 and not self.V1:
             # vLLM hangs if resources are set in ray.remote
             # also seems to only work with decorator and not with ray.remote() fn
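
The vLLM side of the same change relies on `LLM.generate` accepting either a single `SamplingParams` or a per-prompt sequence of them (in current vLLM releases the sequence must have the same length as the prompts; worth re-checking against the installed version). A hedged standalone sketch, assuming a local vLLM install and a placeholder model name:

# Model name and prompts are placeholders for illustration only.
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")
prompts = ["The capital of France is", "Once upon a time"]
per_request = [
    SamplingParams(temperature=0.0, max_tokens=16),             # greedy request
    SamplingParams(temperature=0.8, max_tokens=32, seed=1234),  # sampled request with its own seed
]
outputs = llm.generate(prompts, sampling_params=per_request)
for out in outputs:
    print(out.outputs[0].text)
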
@@ -382,7 +379,7 @@ class VLLM(TemplateLM):
             @ray.remote
             def run_inference_one_model(
                 model_args: dict,
-                sampling_params: SamplingParams,
+                sampling_params: List[SamplingParams],
                 requests: List[List[int]],
                 lora_request: LoRARequest,
             ):
@@ -396,9 +393,12 @@ class VLLM(TemplateLM):
             # dispatch requests to all self.data_parallel_size workers, in interleaved fashion
             # interleaved important to balance context lengths across workers
             requests = [list(x) for x in distribute(self.data_parallel_size, requests)]
+            sampling_params = [
+                list(sp) for sp in distribute(self.data_parallel_size, sampling_params)
+            ]
             inputs = (
-                (self.model_args, sampling_params, req, self.lora_request)
-                for req in requests
+                (self.model_args, sp, req, self.lora_request)
+                for req, sp in zip(requests, sampling_params)
             )
             object_refs = [run_inference_one_model.remote(*x) for x in inputs]
             results = ray.get(object_refs)
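
Both `requests` and `sampling_params` are sharded with the same `distribute` call, assumed here to be `more_itertools.distribute`, which deals items out round-robin; applying it identically to both lists keeps the i-th request and its sampling params on the same worker in the same position. A quick check of that invariant:

from more_itertools import distribute

requests = [[10], [20], [30], [40], [50]]
sampling_params = ["sp0", "sp1", "sp2", "sp3", "sp4"]  # stand-ins for SamplingParams

req_shards = [list(x) for x in distribute(2, requests)]
sp_shards = [list(x) for x in distribute(2, sampling_params)]
print(req_shards)  # [[[10], [30], [50]], [[20], [40]]]
print(sp_shards)   # [['sp0', 'sp2', 'sp4'], ['sp1', 'sp3']]
# Pairing is preserved shard-by-shard because both lists were dealt identically.
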
@@ -413,16 +413,18 @@ class VLLM(TemplateLM):
             dp_master_port = os.environ.get("VLLM_DP_MASTER_PORT") or get_open_port()
             requests = (list(x) for x in distribute(self.data_parallel_size, requests))
+            sampling_params = (
+                list(sp) for sp in distribute(self.data_parallel_size, sampling_params)
+            )
             procs, resq = [], Queue()
             # We use Process as it is non-daemonic
             try:
-                for rank, req in enumerate(requests):
+                for rank, (sp, req) in enumerate(zip(requests, sampling_params)):
                     proc = Process(
                         target=_vllm_mp_worker,
                         args=(
                             self.model_args.copy(),
-                            sampling_params,
+                            sp,
                             req,
                             self.lora_request,
                             resq,
@@ -576,10 +578,11 @@ class VLLM(TemplateLM):
             # - any OOMs will happen right away rather than near the end
             return -len(_requests[0][1]), _requests[0][0]
-        # we group requests by their generation_kwargs,
-        # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
-        # in the same batch.
-        re_ords = Collator(requests, _collate_gen, group_by="gen_kwargs")
+        re_ords = Collator(
+            requests,
+            _collate_gen,
+            group_by=None,
+        )
         chunks = re_ords.get_batched(
             n=int(self.batch_size) if self.batch_size != "auto" else 0, batch_fn=None
         )
@@ -594,41 +597,44 @@ class VLLM(TemplateLM):
         for chunk in chunks:
             context_and_encoding, all_gen_kwargs = zip(*chunk)
             context, context_encoding = zip(*context_and_encoding)
-            # we assume all gen kwargs in the batch are the same
-            # this is safe to assume because the `grouper` object ensures it.
-            gen_kwargs = all_gen_kwargs[0]
-            # unpack our keyword arguments.
-            if isinstance(gen_kwargs, dict):
-                kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
-                # add EOS token to stop sequences
-                until = handle_stop_sequences(kwargs.pop("until", None), eos=eos)
-            else:
-                raise ValueError(
-                    f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
-                )
-            if "max_gen_toks" in kwargs.keys():
-                max_gen_toks = kwargs.pop("max_gen_toks")
-            else:
-                max_gen_toks = self.max_gen_toks
-            # set the max length in tokens of inputs ("context_enc")
-            # max len for inputs = max length, minus room to generate the max new tokens
-            max_ctx_len = self.max_length - max_gen_toks
-            all_lengths = [len(x) for x in context_encoding]
-            for length in all_lengths:
-                if length > max_ctx_len:
-                    eval_logger.warning(
-                        f"Context length {length} exceeds max length (context + max gen tokens): {max_ctx_len}. Truncating context."
-                    )
-            context_encoding = [x[-max_ctx_len:] for x in context_encoding]
+            context_encoding_truncated = []
+            sampling_params = []
+            for x, gen_kwargs in zip(context_encoding, all_gen_kwargs):
+                # unpack our keyword arguments.
+                if isinstance(gen_kwargs, dict):
+                    kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
+                    # add EOS token to stop sequences
+                    until = handle_stop_sequences(kwargs.pop("until", None), eos=eos)
+                else:
+                    raise ValueError(
+                        f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
+                    )
+                if "max_gen_toks" in kwargs.keys():
+                    max_gen_toks = kwargs.pop("max_gen_toks")
+                else:
+                    max_gen_toks = self.max_gen_toks
+                # set the max length in tokens of inputs ("context_enc")
+                # max len for inputs = max length, minus room to generate the max new tokens
+                max_ctx_len = self.max_length - max_gen_toks
+                if len(x) > max_ctx_len:
+                    eval_logger.warning(
+                        f"Context length {len(x)} exceeds max length (context + max gen tokens): {max_ctx_len}. Truncating context."
+                    )
+                    context_encoding_truncated.append(x[-max_ctx_len:])
+                else:
+                    context_encoding_truncated.append(x)
+                # create sampling params
+                kwargs = self.modify_gen_kwargs(kwargs)
+                sampling_params.append(
+                    SamplingParams(max_tokens=max_gen_toks, stop=until, **kwargs)
+                )
             # perform batched generation
             cont = self._model_generate(
-                requests=context_encoding,
+                requests=context_encoding_truncated,
                 generate=True,
-                max_tokens=max_gen_toks,
-                stop=until,
-                **kwargs,
+                sampling_params=sampling_params,
            )
             # cache generations
...
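
One per-request detail shared by both loops above: when a context is longer than max_ctx_len, only the trailing max_ctx_len tokens are kept (left-truncation), so the tokens closest to the generation point survive. In isolation, with made-up lengths:

max_length, max_gen_toks = 10, 4
max_ctx_len = max_length - max_gen_toks  # room left for the prompt: 6 tokens
context_encoding = list(range(9))        # a 9-token prompt, too long by 3
if len(context_encoding) > max_ctx_len:
    truncated = context_encoding[-max_ctx_len:]
else:
    truncated = context_encoding
print(truncated)  # [3, 4, 5, 6, 7, 8] -- the earliest tokens are dropped
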