Commit 4b4eeb26 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge remote-tracking branch 'mirror/main'

parents 2216a4e5 4fdc581f
......@@ -5,10 +5,12 @@ from vllm.platforms import current_platform
logger = init_logger(__name__)
# neuron has too old torch
HAS_TRITON = find_spec(
"triton") is not None and not current_platform.is_neuron()
HAS_TRITON = (
find_spec("triton") is not None
and not current_platform.is_xpu() # Not compatible
and not current_platform.is_neuron() # neuron has too old torch
)
if not HAS_TRITON:
logger.info("Triton not installed; certain GPU-related functions"
" will not be available.")
logger.info("Triton not installed or not compatible; certain GPU-related"
" functions will not be available.")
......@@ -327,29 +327,6 @@ def is_openvino() -> bool:
return False
@lru_cache(maxsize=None)
def is_xpu() -> bool:
from importlib.metadata import PackageNotFoundError, version
try:
is_xpu_flag = "xpu" in version("vllm")
except PackageNotFoundError:
return False
# vllm is not build with xpu
if not is_xpu_flag:
return False
try:
import intel_extension_for_pytorch as ipex # noqa: F401
_import_ipex = True
except ImportError as e:
logger.warning("Import Error for IPEX: %s", e.msg)
_import_ipex = False
# ipex dependency is not ready
if not _import_ipex:
logger.warning("not found ipex lib")
return False
return hasattr(torch, "xpu") and torch.xpu.is_available()
@lru_cache(maxsize=None)
def get_max_shared_memory_bytes(gpu: int = 0) -> int:
"""Returns the maximum shared memory per thread block in bytes."""
......@@ -379,7 +356,7 @@ def seed_everything(seed: int) -> None:
if current_platform.is_cuda_alike():
torch.cuda.manual_seed_all(seed)
if is_xpu():
if current_platform.is_xpu():
torch.xpu.manual_seed_all(seed)
......@@ -774,7 +751,7 @@ def is_pin_memory_available() -> bool:
print_warning_once("Using 'pin_memory=False' as WSL is detected. "
"This may slow down the performance.")
return False
elif is_xpu():
elif current_platform.is_xpu():
print_warning_once("Pin memory is not supported on XPU.")
return False
elif current_platform.is_neuron():
......@@ -795,7 +772,7 @@ class DeviceMemoryProfiler:
if current_platform.is_cuda_alike():
torch.cuda.reset_peak_memory_stats(self.device)
mem = torch.cuda.max_memory_allocated(self.device)
elif is_xpu():
elif current_platform.is_xpu():
torch.xpu.reset_peak_memory_stats(self.device) # type: ignore
mem = torch.xpu.max_memory_allocated(self.device) # type: ignore
return mem
......
......@@ -300,6 +300,7 @@ class LLMEngine:
def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
self.scheduler.finish_requests(request_id,
RequestStatus.FINISHED_ABORTED)
self._free_request(request_id)
def get_num_unfinished_requests(self) -> int:
"""Gets the number of unfinished requests."""
......@@ -361,6 +362,11 @@ class LLMEngine:
num_reqs = len(detokenizer_output.req_ids)
for i in range(num_reqs):
req_id = detokenizer_output.req_ids[i]
if req_id not in self.requests:
# The request has been aborted while the detokenizer was
# processing the outputs.
continue
req = self.requests[req_id]
req.output_text += detokenizer_output.detokenized_texts[i]
......@@ -373,9 +379,7 @@ class LLMEngine:
req_outputs.append(req_output)
if finished:
del self.requests[req_id]
del self.num_lagged_steps[req_id]
del self.request_outputs[req_id]
self._free_request(req_id)
return req_outputs
def terminate_detokenizer(self) -> None:
......@@ -440,6 +444,11 @@ class LLMEngine:
req_output.finished = finished
return req_output
def _free_request(self, request_id: str) -> None:
self.requests.pop(request_id, None)
self.num_lagged_steps.pop(request_id, None)
self.request_outputs.pop(request_id, None)
def check_health(self) -> None:
if self.tokenizer:
self.tokenizer.check_health()
......
......@@ -17,7 +17,7 @@ from vllm.distributed import (ensure_model_parallel_initialized,
from vllm.distributed.parallel_state import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.utils import is_xpu
from vllm.platforms import current_platform
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.worker import Worker
from vllm.worker.worker_base import LoraNotSupportedWorkerBase
......@@ -53,7 +53,7 @@ class XPUWorker(LoraNotSupportedWorkerBase, Worker):
observability_config: Optional[ObservabilityConfig] = None,
) -> None:
assert device_config.device_type == "xpu"
assert is_xpu()
assert current_platform.is_xpu()
self.model_config = model_config
self.parallel_config = parallel_config
......@@ -91,7 +91,8 @@ class XPUWorker(LoraNotSupportedWorkerBase, Worker):
self.gpu_cache: Optional[List[List[torch.Tensor]]]
def init_device(self) -> None:
if self.device_config.device.type == "xpu" and is_xpu():
if self.device_config.device.type == "xpu" and current_platform.is_xpu(
):
self.device = torch.device(f"xpu:{self.local_rank}")
torch.xpu.set_device(self.device)
torch.xpu.empty_cache()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment