Unverified Commit 838a3e03 authored by Chris Kerwell Gresla, committed by GitHub

Fix LoRA requests when running data parallel with vLLM (#2433)



* fix: use lora_request for data parallel vllm evals

* fix(docs): include type hint

* chore: lint, et pre-commit al

---------
Co-authored-by: Chris Kerwell Gresla <chris@wafer.systems>
parent 7882043b
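The change threads the LoRA adapter through both generation paths. For background, a minimal, self-contained sketch of the per-request LoRA mechanism in vLLM that this relies on; the model name, adapter path, and the "finetuned" adapter name are illustrative placeholders, not values taken from this commit:

    # Hedged sketch of vLLM's per-request LoRA API (not the harness code itself).
    from vllm import LLM, SamplingParams
    from vllm.lora.request import LoRARequest

    # enable_lora must be set so the engine can load adapters at request time
    llm = LLM(model="meta-llama/Llama-2-7b-hf", enable_lora=True)

    # LoRARequest(adapter name, unique integer id, path to the adapter weights)
    lora_request = LoRARequest("finetuned", 1, "/path/to/adapter")

    outputs = llm.generate(
        ["The capital of France is"],
        SamplingParams(max_tokens=16),
        lora_request=lora_request,  # applied per call, so each worker must receive it explicitly
    )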
@@ -239,17 +239,25 @@ class VLLM(TemplateLM):
             # but then tensor_parallel breaks
             @ray.remote
             def run_inference_one_model(
-                model_args: dict, sampling_params, requests: List[List[int]]
+                model_args: dict,
+                sampling_params,
+                requests: List[List[int]],
+                lora_request: LoRARequest,
             ):
                 llm = LLM(**model_args)
                 return llm.generate(
-                    prompt_token_ids=requests, sampling_params=sampling_params
+                    prompt_token_ids=requests,
+                    sampling_params=sampling_params,
+                    lora_request=lora_request,
                 )

             # dispatch requests to all self.data_parallel_size workers, in interleaved fashion
             # interleaved important to balance context lengths across workers
             requests = [list(x) for x in distribute(self.data_parallel_size, requests)]
-            inputs = ((self.model_args, sampling_params, req) for req in requests)
+            inputs = (
+                (self.model_args, sampling_params, req, self.lora_request)
+                for req in requests
+            )
             object_refs = [run_inference_one_model.remote(*x) for x in inputs]
             results = ray.get(object_refs)
             # Invoke ray.shutdown() to prevent hang-ups if subsequent calls required.
@@ -257,19 +265,12 @@ class VLLM(TemplateLM):
             # flatten results
             return undistribute(results)

-        if self.lora_request is not None:
-            outputs = self.model.generate(
-                prompt_token_ids=requests,
-                sampling_params=sampling_params,
-                use_tqdm=True if self.batch_size == "auto" else False,
-                lora_request=self.lora_request,
-            )
-        else:
-            outputs = self.model.generate(
-                prompt_token_ids=requests,
-                sampling_params=sampling_params,
-                use_tqdm=True if self.batch_size == "auto" else False,
-            )
+        outputs = self.model.generate(
+            prompt_token_ids=requests,
+            sampling_params=sampling_params,
+            use_tqdm=True if self.batch_size == "auto" else False,
+            lora_request=self.lora_request,
+        )
         return outputs

     def loglikelihood_rolling(
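In the data-parallel branch each Ray worker builds its own LLM, so the adapter handle has to travel with the dispatched arguments rather than live only on self. A stripped-down sketch of that dispatch pattern; the two-worker split and the echo-style worker body are placeholders for the real LLM(**model_args).generate(..., lora_request=...) call:

    import ray
    from more_itertools import distribute

    @ray.remote
    def run_one_shard(shard, lora_request):
        # the harness would construct LLM(**model_args) here and call
        # llm.generate(..., lora_request=lora_request); echoing keeps the sketch runnable
        return [(item, lora_request) for item in shard]

    ray.init()
    requests = list(range(8))
    # interleaved split so each worker sees a similar mix of context lengths
    shards = [list(x) for x in distribute(2, requests)]
    refs = [run_one_shard.remote(shard, "adapter-handle") for shard in shards]
    results = ray.get(refs)
    ray.shutdown()

The non-data-parallel path can drop its if/else because generate accepts lora_request=None, which simply leaves the adapter disabled.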