Unverified Commit 838a3e03, authored by Chris Kerwell Gresla, committed by GitHub

Fix lora requests when dp with vllm (#2433)



* fix: use lora_request for data parallel vllm evals

* fix(docs): include type hint

* chore: lint and pre-commit fixes

---------
Co-authored-by: Chris Kerwell Gresla <chris@wafer.systems>
parent 7882043b
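
For context: self.lora_request is the LoRARequest the harness builds when a LoRA adapter is configured. As the diff below shows, it previously reached only the single-instance generate path, so data-parallel ray workers ran without the adapter. A minimal sketch of how such a request is built and used with vLLM's public API; the model name, adapter name, and path are hypothetical:

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# enable_lora must be set for vLLM to accept adapter requests
llm = LLM(model="meta-llama/Llama-2-7b-hf", enable_lora=True)
sampling_params = SamplingParams(temperature=0.0, max_tokens=64)

# positional args: adapter name, integer id, local adapter path
# (older vLLM versions name the third argument lora_local_path)
lora_request = LoRARequest("my_adapter", 1, "/path/to/adapter")

# passing lora_request applies the adapter's weights during generation
outputs = llm.generate(
    ["Hello, world"],
    sampling_params,
    lora_request=lora_request,
)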
@@ -239,17 +239,25 @@ class VLLM(TemplateLM):
             # but then tensor_parallel breaks
             @ray.remote
             def run_inference_one_model(
-                model_args: dict, sampling_params, requests: List[List[int]]
+                model_args: dict,
+                sampling_params,
+                requests: List[List[int]],
+                lora_request: LoRARequest,
             ):
                 llm = LLM(**model_args)
                 return llm.generate(
-                    prompt_token_ids=requests, sampling_params=sampling_params
+                    prompt_token_ids=requests,
+                    sampling_params=sampling_params,
+                    lora_request=lora_request,
                 )
 
             # dispatch requests to all self.data_parallel_size workers, in interleaved fashion
             # interleaved important to balance context lengths across workers
             requests = [list(x) for x in distribute(self.data_parallel_size, requests)]
-            inputs = ((self.model_args, sampling_params, req) for req in requests)
+            inputs = (
+                (self.model_args, sampling_params, req, self.lora_request)
+                for req in requests
+            )
             object_refs = [run_inference_one_model.remote(*x) for x in inputs]
             results = ray.get(object_refs)
             # Invoke ray.shutdown() to prevent hang-ups if subsequent calls required.
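
Aside on the dispatch above: distribute comes from more_itertools and deals requests out round-robin, so long and short prompts spread evenly across workers; undistribute then re-interleaves the per-worker outputs back into the original request order. A small self-contained sketch; the undistribute shown here is a simplified stand-in for the harness helper and assumes every worker received the same number of requests:

from more_itertools import distribute

requests = [[1], [2, 2], [3, 3, 3], [4], [5, 5], [6, 6, 6]]

# round-robin split: worker 0 gets items 0, 2, 4; worker 1 gets items 1, 3, 5
workers = [list(chunk) for chunk in distribute(2, requests)]
assert workers == [[[1], [3, 3, 3], [5, 5]], [[2, 2], [4], [6, 6, 6]]]

# simplified inverse (the harness ships its own undistribute); zip assumes
# an even split across workers
def undistribute(per_worker):
    return [item for group in zip(*per_worker) for item in group]

assert undistribute(workers) == requests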
@@ -257,19 +265,12 @@ class VLLM(TemplateLM):
             # flatten results
             return undistribute(results)
 
-        if self.lora_request is not None:
-            outputs = self.model.generate(
-                prompt_token_ids=requests,
-                sampling_params=sampling_params,
-                use_tqdm=True if self.batch_size == "auto" else False,
-                lora_request=self.lora_request,
-            )
-        else:
-            outputs = self.model.generate(
-                prompt_token_ids=requests,
-                sampling_params=sampling_params,
-                use_tqdm=True if self.batch_size == "auto" else False,
-            )
+        outputs = self.model.generate(
+            prompt_token_ids=requests,
+            sampling_params=sampling_params,
+            use_tqdm=True if self.batch_size == "auto" else False,
+            lora_request=self.lora_request,
+        )
         return outputs
 
     def loglikelihood_rolling(
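
The second hunk can collapse the if/else because LLM.generate already defaults lora_request to None and treats None as "no adapter", so passing self.lora_request unconditionally behaves identically. A minimal sketch of that equivalence, using a hypothetical small model:

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")
params = SamplingParams(max_tokens=8)

# both calls are equivalent when no adapter is configured, which is why
# the old guard around lora_request was redundant
out_a = llm.generate(["hi"], params)
out_b = llm.generate(["hi"], params, lora_request=None)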