Unverified commit bc811365, authored by Cyrus Leung and committed by GitHub

Update vLLM compatibility (#3024)



* Update vLLM compatibility
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

* add TokensPrompt to all generate calls

---------
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: Baber <baber@hey.com>
parent 4f8195f1
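
Newer vLLM versions expect pre-tokenized prompts to be wrapped in TokensPrompt rather than passed through the prompt_token_ids= keyword of LLM.generate; the hunks below make that change at every generate call site. The following is a minimal sketch of the new calling convention, not code from this repository; the model name and token IDs are illustrative placeholders.

# Minimal sketch of the TokensPrompt calling convention adopted below.
# The model name and the token IDs are illustrative placeholders.
from vllm import LLM, SamplingParams, TokensPrompt

llm = LLM(model="facebook/opt-125m")
sampling_params = SamplingParams(temperature=0.0, max_tokens=16)

requests = [[1, 2, 3, 4], [5, 6, 7]]  # pre-tokenized prompts (lists of token IDs)

# Old style this commit replaces: llm.generate(prompt_token_ids=requests, ...)
outputs = llm.generate(
    [TokensPrompt(prompt_token_ids=request) for request in requests],
    sampling_params=sampling_params,
)
for output in outputs:
    print(output.outputs[0].text)
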
 import copy
 import gc
-import inspect
 import logging
 import os
 from importlib.metadata import version
@@ -33,7 +32,7 @@ from lm_eval.utils import (
 try:
     import ray
-    from vllm import LLM, SamplingParams
+    from vllm import LLM, SamplingParams, TokensPrompt
     from vllm.lora.request import LoRARequest
     from vllm.transformers_utils.tokenizer import get_tokenizer
     from vllm.utils import get_open_port
@@ -79,7 +78,7 @@ def _vllm_mp_worker(
     try:
         llm = LLM(**model_args)
         res = llm.generate(
-            prompt_token_ids=requests,
+            [TokensPrompt(prompt_token_ids=request) for request in requests],
             sampling_params=sampling_params,
             lora_request=lora_request,
         )
@@ -239,13 +238,6 @@ class VLLM(TemplateLM):
             model_config = engine_args.create_model_config()
             kwargs_resolve_hf_chat_template["model_config"] = model_config
-            # https://github.com/vllm-project/vllm/pull/18259
-            if (
-                "trsut_remote_code"
-                in inspect.signature(resolve_hf_chat_template).parameters
-            ):
-                kwargs_resolve_hf_chat_template["trsut_remote_code"] = trust_remote_code
         else:
             kwargs_resolve_hf_chat_template["trust_remote_code"] = trust_remote_code
@@ -395,7 +387,7 @@
         ):
             llm = LLM(**model_args)
             return llm.generate(
-                prompt_token_ids=requests,
+                [TokensPrompt(prompt_token_ids=request) for request in requests],
                 sampling_params=sampling_params,
                 lora_request=lora_request,
             )
@@ -484,7 +476,7 @@
         else:
            outputs = self.model.generate(
-                prompt_token_ids=requests,
+                [TokensPrompt(prompt_token_ids=request) for request in requests],
                 sampling_params=sampling_params,
                 use_tqdm=True if self.batch_size == "auto" else False,
                 lora_request=self.lora_request,
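
On the removed inspect probe above: per the comment dropped in that hunk, an older vLLM build exposed the chat-template helper's trust-remote-code flag under the misspelled keyword trsut_remote_code (see vllm-project/vllm#18259), so the harness inspected the function signature at runtime and passed whichever spelling existed. The vLLM versions this commit targets accept the correctly spelled keyword, so the probe and the import inspect go away. Below is a minimal, generic sketch of the removed probing pattern; configure_trust_kwarg and some_helper are hypothetical names for illustration, not part of vLLM or this repository.

import inspect

def configure_trust_kwarg(some_helper, trust_remote_code: bool) -> dict:
    # Probe the callable's signature and use whichever keyword spelling it accepts.
    # `some_helper` stands in for the version-dependent vLLM helper; these names
    # are illustrative, not the actual harness code.
    kwargs = {}
    if "trsut_remote_code" in inspect.signature(some_helper).parameters:
        kwargs["trsut_remote_code"] = trust_remote_code  # old, misspelled keyword
    else:
        kwargs["trust_remote_code"] = trust_remote_code  # corrected keyword
    return kwargs
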