Unverified commit 728c4c8a, authored by Kunshang Ji and committed by GitHub
Browse files

[Hardware][Intel GPU] Add Intel GPU(XPU) inference backend (#3814)


Co-authored-by: Jiang Li <jiang1.li@intel.com>
Co-authored-by: Abhilash Majumder <abhilash.majumder@intel.com>
Co-authored-by: Abhilash Majumder <30946547+abhilash1910@users.noreply.github.com>
parent 1f12122b
from typing import List, Optional
import torch
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig,
SpeculativeConfig, VisionLanguageConfig)
from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.gpu_executor import GPUExecutor
from vllm.logger import init_logger
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.utils import make_async
from vllm.worker.worker_base import WorkerWrapperBase
logger = init_logger(__name__)
class XPUExecutor(GPUExecutor):
    """Executor for the XPU backend.

    Requests are forwarded to ``self.driver_worker``; workers are created
    from ``vllm.worker.xpu_worker.XPUWorker``. Speculative decoding is
    rejected.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        cache_config: CacheConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        vision_language_config: Optional[VisionLanguageConfig],
        speculative_config: Optional[SpeculativeConfig],
    ) -> None:
        """Check the configuration, store it, and start the executor.

        The device type must be ``"xpu"`` and ``speculative_config`` must be
        falsy; both are enforced with assertions.
        """
        assert device_config.device_type == "xpu"
        assert (not speculative_config
                ), "Speculative decoding not yet supported for XPU backend"

        # The model config is adjusted for XPU limits before it is stored.
        self.model_config = _verify_and_get_model_config(model_config)
        self.cache_config = cache_config
        self.load_config = load_config
        self.lora_config = lora_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.vision_language_config = vision_language_config
        # Always None: a truthy speculative_config was rejected above.
        self.speculative_config = None

        # Worker creation is delegated to the inherited initialiser.
        self._init_executor()

    def _create_worker(self,
                       local_rank: int = 0,
                       rank: int = 0,
                       distributed_init_method: Optional[str] = None):
        """Build and return an initialised ``XPUWorker``.

        Raises:
            NotImplementedError: if a speculative config is present.
        """
        if self.speculative_config is not None:
            raise NotImplementedError(
                "XPU does not support speculative decoding")

        wrapper = WorkerWrapperBase(
            worker_module_name="vllm.worker.xpu_worker",
            worker_class_name="XPUWorker",
        )
        worker_kwargs = self._get_worker_kwargs(local_rank, rank,
                                                distributed_init_method)
        wrapper.init_worker(**worker_kwargs)
        return wrapper.worker

    def execute_model(
            self,
            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        """Run one request on the driver worker and return its output."""
        return self.driver_worker.execute_model(execute_model_req)
class XPUExecutorAsync(XPUExecutor, ExecutorAsyncBase):
    """Async flavour of ``XPUExecutor``."""

    async def execute_model_async(
        self,
        execute_model_req: ExecuteModelRequest,
    ) -> List[SamplerOutput]:
        """Await the driver worker's ``execute_model`` via ``make_async``."""
        run_async = make_async(self.driver_worker.execute_model)
        return await run_async(execute_model_req=execute_model_req)
def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig:
    """Adjust ``config`` in place for XPU limits and return it.

    A bfloat16 dtype is replaced by float16, and eager mode is forced on;
    each adjustment is announced with a warning.
    """
    needs_fp16_cast = config.dtype == torch.bfloat16
    if needs_fp16_cast:
        logger.warning(
            "bfloat16 is not fully supported on XPU, casting to float16.")
        config.dtype = torch.float16

    # Nothing more to do when eager execution is already requested.
    if config.enforce_eager:
        return config

    logger.warning(
        "CUDA graph is not supported on XPU, fallback to the eager "
        "mode.")
    config.enforce_eager = True
    return config
import torch.nn as nn import torch.nn as nn
from vllm.utils import is_cpu, is_hip, is_tpu from vllm.utils import is_cpu, is_hip, is_tpu, is_xpu
class CustomOp(nn.Module): class CustomOp(nn.Module):
...@@ -29,9 +29,7 @@ class CustomOp(nn.Module): ...@@ -29,9 +29,7 @@ class CustomOp(nn.Module):
return self.forward_cuda(*args, **kwargs) return self.forward_cuda(*args, **kwargs)
def forward_xpu(self, *args, **kwargs): def forward_xpu(self, *args, **kwargs):
# By default, we assume that XPU ops are compatible with CUDA ops. raise NotImplementedError
# NOTE(woosuk): This is a placeholder for future extensions.
return self.forward_cuda(*args, **kwargs)
def forward_cpu(self, *args, **kwargs): def forward_cpu(self, *args, **kwargs):
# By default, we assume that CPU ops are compatible with CUDA ops. # By default, we assume that CPU ops are compatible with CUDA ops.
...@@ -58,5 +56,7 @@ class CustomOp(nn.Module): ...@@ -58,5 +56,7 @@ class CustomOp(nn.Module):
return self.forward_cpu return self.forward_cpu
elif is_tpu(): elif is_tpu():
return self.forward_tpu return self.forward_tpu
elif is_xpu():
return self.forward_xpu
else: else:
return self.forward_cuda return self.forward_cuda
...@@ -37,6 +37,15 @@ class SiluAndMul(CustomOp): ...@@ -37,6 +37,15 @@ class SiluAndMul(CustomOp):
ops.silu_and_mul(out, x) ops.silu_and_mul(out, x)
return out return out
def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
    """Run the IPEX ``silu_and_mul`` kernel on ``x``.

    The result buffer keeps every leading dimension of ``x`` and has half
    of its last dimension (integer division), with the same dtype and
    device.
    """
    from vllm._ipex_ops import ipex_ops as ops

    half_width = x.shape[-1] // 2
    result = torch.empty(x.shape[:-1] + (half_width, ),
                         dtype=x.dtype,
                         device=x.device)
    # The kernel writes into the buffer passed as its first argument.
    ops.silu_and_mul(result, x)
    return result
class GeluAndMul(CustomOp): class GeluAndMul(CustomOp):
"""An activation function for GeGLU. """An activation function for GeGLU.
...@@ -71,6 +80,18 @@ class GeluAndMul(CustomOp): ...@@ -71,6 +80,18 @@ class GeluAndMul(CustomOp):
ops.gelu_tanh_and_mul(out, x) ops.gelu_tanh_and_mul(out, x)
return out return out
def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
    """Run the IPEX GELU-and-multiply kernel selected by ``self.approximate``.

    ``"none"`` selects ``gelu_and_mul`` and ``"tanh"`` selects
    ``gelu_tanh_and_mul``. For any other value no kernel is called and the
    freshly allocated buffer is returned as is.
    """
    from vllm._ipex_ops import ipex_ops as ops

    half_width = x.shape[-1] // 2
    result = torch.empty(x.shape[:-1] + (half_width, ),
                         dtype=x.dtype,
                         device=x.device)
    mode = self.approximate
    if mode == "tanh":
        ops.gelu_tanh_and_mul(result, x)
    elif mode == "none":
        ops.gelu_and_mul(result, x)
    return result
def extra_repr(self) -> str: def extra_repr(self) -> str:
return f'approximate={repr(self.approximate)}' return f'approximate={repr(self.approximate)}'
...@@ -90,6 +111,13 @@ class NewGELU(CustomOp): ...@@ -90,6 +111,13 @@ class NewGELU(CustomOp):
ops.gelu_new(out, x) ops.gelu_new(out, x)
return out return out
def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
    """Run the IPEX ``gelu_new`` kernel into a buffer shaped like ``x``."""
    from vllm._ipex_ops import ipex_ops as ops

    result = torch.empty_like(x)
    ops.gelu_new(result, x)
    return result
class FastGELU(CustomOp): class FastGELU(CustomOp):
...@@ -105,6 +133,13 @@ class FastGELU(CustomOp): ...@@ -105,6 +133,13 @@ class FastGELU(CustomOp):
ops.gelu_fast(out, x) ops.gelu_fast(out, x)
return out return out
def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
    """Run the IPEX ``gelu_fast`` kernel into a buffer shaped like ``x``."""
    from vllm._ipex_ops import ipex_ops as ops

    result = torch.empty_like(x)
    ops.gelu_fast(result, x)
    return result
class ScaledActivation(nn.Module): class ScaledActivation(nn.Module):
"""An activation function with post-scale parameters. """An activation function with post-scale parameters.
......
...@@ -67,6 +67,30 @@ class RMSNorm(CustomOp): ...@@ -67,6 +67,30 @@ class RMSNorm(CustomOp):
) )
return out return out
def forward_xpu(
    self,
    x: torch.Tensor,
    residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """Apply RMS normalisation through the IPEX kernels.

    Without ``residual`` a new tensor shaped like ``x`` is returned. With
    ``residual`` the fused kernel is called and the same ``x`` and
    ``residual`` objects are handed back as a pair.
    """
    from vllm._ipex_ops import ipex_ops as ops

    weight = self.weight.data
    eps = self.variance_epsilon

    if residual is None:
        normed = torch.empty_like(x)
        ops.rms_norm(normed, x, weight, eps)
        return normed

    # NOTE(review): presumably updates x and residual in place, since the
    # inputs themselves are returned; confirm against the kernel.
    ops.fused_add_rms_norm(x, residual, weight, eps)
    return x, residual
def extra_repr(self) -> str: def extra_repr(self) -> str:
s = f"hidden_size={self.weight.data.size(0)}" s = f"hidden_size={self.weight.data.size(0)}"
s += f", eps={self.variance_epsilon}" s += f", eps={self.variance_epsilon}"
......
...@@ -221,6 +221,29 @@ class RotaryEmbedding(CustomOp): ...@@ -221,6 +221,29 @@ class RotaryEmbedding(CustomOp):
self.cos_sin_cache, self.is_neox_style) self.cos_sin_cache, self.is_neox_style)
return query, key return query, key
def forward_xpu(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary embeddings to ``query`` and ``key`` via the IPEX kernels.

    The cached cos/sin table is first moved to the device of ``positions``
    and cast to the dtype of ``query`` (the converted table replaces
    ``self.cos_sin_cache``). The batched kernel is used when ``offsets`` is
    given. The same ``query`` and ``key`` objects are returned.
    """
    from vllm._ipex_ops import ipex_ops as ops

    cache = self.cos_sin_cache.to(positions.device, dtype=query.dtype)
    self.cos_sin_cache = cache

    # Both kernels update the query and key tensors in place.
    if offsets is None:
        ops.rotary_embedding(positions, query, key, self.head_size, cache,
                             self.is_neox_style)
    else:
        ops.batched_rotary_embedding(positions, query, key, self.head_size,
                                     cache, self.is_neox_style,
                                     self.rotary_dim, offsets)
    return query, key
def forward_tpu( def forward_tpu(
self, self,
positions: torch.Tensor, positions: torch.Tensor,
......
...@@ -307,7 +307,7 @@ class VocabParallelEmbedding(torch.nn.Module): ...@@ -307,7 +307,7 @@ class VocabParallelEmbedding(torch.nn.Module):
else: else:
masked_input = input_ masked_input = input_
# Get the embeddings. # Get the embeddings.
output_parallel = F.embedding(masked_input, self.weight) output_parallel = F.embedding(masked_input.long(), self.weight)
# Mask the output embedding. # Mask the output embedding.
if self.tp_size > 1: if self.tp_size > 1:
output_parallel.masked_fill_(input_mask.unsqueeze(1), 0) output_parallel.masked_fill_(input_mask.unsqueeze(1), 0)
......
...@@ -160,6 +160,26 @@ def is_tpu() -> bool: ...@@ -160,6 +160,26 @@ def is_tpu() -> bool:
return libtpu is not None return libtpu is not None
@lru_cache(maxsize=None)
def is_xpu() -> bool:
    """Return True when this vLLM build targets XPU and an XPU is usable.

    The answer is computed once and cached. It is True only if all of the
    following hold:

    * the installed ``vllm`` distribution's version string contains "xpu";
    * ``intel_extension_for_pytorch`` can be imported;
    * ``torch.xpu`` exists and reports an available device.

    Returns:
        bool: whether the XPU backend can be used in this process.
    """
    from importlib.metadata import PackageNotFoundError, version

    try:
        built_for_xpu = "xpu" in version("vllm")
    except PackageNotFoundError:
        # No distribution metadata for vllm (for example when running from
        # a source tree): treat it as "not an XPU build" instead of
        # crashing every caller of this probe.
        return False
    # vllm is not built with xpu
    if not built_for_xpu:
        return False

    try:
        import intel_extension_for_pytorch as ipex  # noqa: F401
    except ImportError as e:
        # ipex dependency is not ready
        logger.warning("Import Error for IPEX: %s", e.msg)
        logger.warning("not found ipex lib")
        return False
    return hasattr(torch, "xpu") and torch.xpu.is_available()
@lru_cache(maxsize=None) @lru_cache(maxsize=None)
def get_max_shared_memory_bytes(gpu: int = 0) -> int: def get_max_shared_memory_bytes(gpu: int = 0) -> int:
"""Returns the maximum shared memory per thread block in bytes.""" """Returns the maximum shared memory per thread block in bytes."""
...@@ -482,6 +502,9 @@ def is_pin_memory_available() -> bool: ...@@ -482,6 +502,9 @@ def is_pin_memory_available() -> bool:
print_warning_once("Using 'pin_memory=False' as WSL is detected. " print_warning_once("Using 'pin_memory=False' as WSL is detected. "
"This may slow down the performance.") "This may slow down the performance.")
return False return False
elif is_xpu():
print_warning_once("Pin memory is not supported on XPU.")
return False
elif is_neuron(): elif is_neuron():
print_warning_once("Pin memory is not supported on Neuron.") print_warning_once("Pin memory is not supported on Neuron.")
return False return False
...@@ -497,8 +520,12 @@ class CudaMemoryProfiler: ...@@ -497,8 +520,12 @@ class CudaMemoryProfiler:
def current_memory_usage(self) -> float: def current_memory_usage(self) -> float:
# Return the memory usage in bytes. # Return the memory usage in bytes.
torch.cuda.reset_peak_memory_stats(self.device) if torch.cuda.is_available():
mem = torch.cuda.max_memory_allocated(self.device) torch.cuda.reset_peak_memory_stats(self.device)
mem = torch.cuda.max_memory_allocated(self.device)
elif is_xpu():
torch.xpu.reset_peak_memory_stats(self.device)
mem = torch.xpu.max_memory_allocated(self.device)
return mem return mem
def __enter__(self): def __enter__(self):
......
...@@ -4,7 +4,7 @@ from typing import List ...@@ -4,7 +4,7 @@ from typing import List
import torch import torch
from vllm.attention import get_attn_backend from vllm.attention import get_attn_backend
from vllm.config import CacheConfig, ModelConfig, ParallelConfig from vllm.config import CacheConfig, DeviceConfig, ModelConfig, ParallelConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size, from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size,
is_pin_memory_available) is_pin_memory_available)
...@@ -25,10 +25,12 @@ class CacheEngine: ...@@ -25,10 +25,12 @@ class CacheEngine:
cache_config: CacheConfig, cache_config: CacheConfig,
model_config: ModelConfig, model_config: ModelConfig,
parallel_config: ParallelConfig, parallel_config: ParallelConfig,
device_config: DeviceConfig,
) -> None: ) -> None:
self.cache_config = cache_config self.cache_config = cache_config
self.model_config = model_config self.model_config = model_config
self.parallel_config = parallel_config self.parallel_config = parallel_config
self.device_config = device_config
self.head_size = model_config.get_head_size() self.head_size = model_config.get_head_size()
self.num_layers = model_config.get_num_layers(parallel_config) self.num_layers = model_config.get_num_layers(parallel_config)
...@@ -55,7 +57,8 @@ class CacheEngine: ...@@ -55,7 +57,8 @@ class CacheEngine:
) )
# Initialize the cache. # Initialize the cache.
self.gpu_cache = self._allocate_kv_cache(self.num_gpu_blocks, "cuda") self.gpu_cache = self._allocate_kv_cache(
self.num_gpu_blocks, self.device_config.device_type)
self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks, "cpu") self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks, "cpu")
def _allocate_kv_cache( def _allocate_kv_cache(
......
...@@ -205,7 +205,8 @@ class Worker(WorkerBase): ...@@ -205,7 +205,8 @@ class Worker(WorkerBase):
def _init_cache_engine(self): def _init_cache_engine(self):
assert self.cache_config.num_gpu_blocks is not None assert self.cache_config.num_gpu_blocks is not None
self.cache_engine = CacheEngine(self.cache_config, self.model_config, self.cache_engine = CacheEngine(self.cache_config, self.model_config,
self.parallel_config) self.parallel_config,
self.device_config)
self.gpu_cache = self.cache_engine.gpu_cache self.gpu_cache = self.cache_engine.gpu_cache
def _warm_up_model(self) -> None: def _warm_up_model(self) -> None:
......
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
from vllm.attention import get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig,
VisionLanguageConfig)
from vllm.distributed import broadcast_tensor_dict
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.sampling_params import SamplingParams
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
from vllm.utils import CudaMemoryProfiler, make_tensor_with_pad
from vllm.worker.model_runner import AttentionMetadata, SamplingMetadata
logger = init_logger(__name__)
# Placeholder slot id used by _prepare_prompt for tokens that get no
# KV-cache slot: every token when block tables are missing (memory
# profiling) and tokens masked out by the sliding window.
_PAD_SLOT_ID = -1
# Step between the larger batch sizes listed below.
_BATCH_SIZE_ALIGNMENT = 8
# Batch sizes 1, 2 and 4, followed by every multiple of
# _BATCH_SIZE_ALIGNMENT from 8 through 256.
# NOTE(review): neither batch-size constant is referenced in the visible
# part of this runner; confirm whether they are still needed.
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
_BATCH_SIZE_ALIGNMENT * i for i in range(1, 33)
]
class XPUModelRunner:
def __init__(
    self,
    model_config: ModelConfig,
    parallel_config: ParallelConfig,
    scheduler_config: SchedulerConfig,
    device_config: DeviceConfig,
    cache_config: CacheConfig,
    load_config: LoadConfig,
    lora_config: Optional[LoRAConfig],
    vision_language_config: Optional[VisionLanguageConfig],
    kv_cache_dtype: Optional[str] = "auto",
    is_driver_worker: bool = False,
    *args,
    **kwargs,
):
    """Store the configuration and pick the attention backend.

    Extra positional and keyword arguments are accepted and ignored. The
    model itself is not created here; see ``load_model``.
    """
    self.model_config = model_config
    self.parallel_config = parallel_config
    self.scheduler_config = scheduler_config
    self.lora_config = lora_config
    self.load_config = load_config
    self.cache_config = cache_config
    self.vision_language_config = vision_language_config
    self.is_driver_worker = is_driver_worker

    self.sliding_window = model_config.get_sliding_window()

    self.device_config = device_config
    self.device = device_config.device

    self.kv_cache_dtype = kv_cache_dtype
    self.block_size = cache_config.block_size

    if self.model_config is not None:
        self.max_context_len_to_capture = (
            self.model_config.max_context_len_to_capture)
    else:
        self.max_context_len_to_capture = 0

    # The backend is chosen from head geometry, sliding window, dtypes and
    # block size.
    mc = self.model_config
    pc = self.parallel_config
    self.attn_backend = get_attn_backend(
        mc.get_num_attention_heads(pc),
        mc.get_head_size(),
        mc.get_num_kv_heads(pc),
        mc.get_sliding_window(),
        mc.dtype,
        self.kv_cache_dtype,
        self.block_size,
    )

    # Declared only; assigned by load_model().
    self.model: nn.Module
def load_model(self) -> None:
    """Load the model and record how much device memory it consumed.

    The load runs inside ``CudaMemoryProfiler``; the measured usage is kept
    in ``self.model_memory_usage`` and logged in GB.
    """
    loader_kwargs = {
        "model_config": self.model_config,
        "device_config": self.device_config,
        "load_config": self.load_config,
        "lora_config": self.lora_config,
        "vision_language_config": self.vision_language_config,
        "parallel_config": self.parallel_config,
        "scheduler_config": self.scheduler_config,
        "cache_config": self.cache_config,
    }
    with CudaMemoryProfiler() as m:
        self.model = get_model(**loader_kwargs)

    self.model_memory_usage = m.consumed_memory
    usage_gb = self.model_memory_usage / float(2**30)
    logger.info("Loading model weights took %.4f GB", usage_gb)
@property
def vocab_size(self) -> int:
    """Vocabulary size as reported by the model config."""
    size = self.model_config.get_vocab_size()
    return size
@torch.inference_mode()
def profile_run(self) -> None:
    """Run one forward pass on dummy prompts sized to the scheduler limits.

    ``max_num_seqs`` prompt groups are built whose lengths add up to
    ``max_num_batched_tokens``; they are executed without KV caches and the
    XPU is synchronised afterwards.
    """
    # Enable top-k sampling to reflect the accurate memory usage.
    sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
    max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
    max_num_seqs = self.scheduler_config.max_num_seqs

    def _dummy_group(group_id: int) -> SequenceGroupMetadata:
        # Spread the token budget evenly; the first
        # (max_num_batched_tokens % max_num_seqs) groups take one extra.
        seq_len = (max_num_batched_tokens // max_num_seqs +
                   (group_id < max_num_batched_tokens % max_num_seqs))
        return SequenceGroupMetadata(
            request_id=str(group_id),
            is_prompt=True,
            seq_data={group_id: SequenceData([0] * seq_len)},
            sampling_params=sampling_params,
            block_tables=None,
            lora_request=None,
            # No multi-modal payload is profiled here.
            multi_modal_data=None,
        )

    seqs: List[SequenceGroupMetadata] = [
        _dummy_group(group_id) for group_id in range(max_num_seqs)
    ]

    # One None placeholder per layer stands in for the KV caches.
    num_layers = self.model_config.get_num_layers(self.parallel_config)
    kv_caches = [None] * num_layers
    self.execute_model(seqs, kv_caches)
    torch.xpu.synchronize()
def prepare_input_tensors(
    self,
    seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
           Optional[torch.Tensor]]:
    """Build the model inputs on the driver, or receive them elsewhere.

    The driver worker prepares prompt or decode tensors, derives the
    sampling metadata, and broadcasts tokens, positions, selected token
    indices and attention metadata from rank 0. Other workers rebuild their
    inputs from that broadcast.

    Returns:
        ``(input_tokens, input_positions, attn_metadata,
        sampling_metadata, multi_modal_input)``. ``multi_modal_input`` is
        always None on non-driver workers.
    """
    multi_modal_input = None

    if not self.is_driver_worker:
        received = broadcast_tensor_dict(src=0)
        input_tokens = received.pop("input_tokens")
        input_positions = received.pop("input_positions")
        selected_token_indices = received.pop("selected_token_indices")
        # Whatever is left in the dict is attention metadata.
        attn_metadata = self.attn_backend.make_metadata(**received)
        sampling_metadata = SamplingMetadata(
            seq_groups=None,
            selected_token_indices=selected_token_indices,
            categorized_sample_indices=None,
            num_prompts=0,
        )
        return (input_tokens, input_positions, attn_metadata,
                sampling_metadata, multi_modal_input)

    # NOTE: We assume that all sequences in the group are all prompts or
    # all decodes.
    if seq_group_metadata_list[0].is_prompt:
        (input_tokens, input_positions, attn_metadata, seq_lens,
         multi_modal_input) = self._prepare_prompt(seq_group_metadata_list)
    else:
        (input_tokens, input_positions,
         attn_metadata) = self._prepare_decode(seq_group_metadata_list)
        seq_lens = []

    # seq_lens is passed twice: once as the sequence lengths and once in
    # place of the subquery lengths.
    sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
                                                 seq_lens,
                                                 seq_lens,
                                                 self.device,
                                                 pin_memory=False)

    # Broadcast the metadata.
    metadata_dict = {
        "input_tokens": input_tokens,
        "input_positions": input_positions,
        "selected_token_indices": sampling_metadata.selected_token_indices,
        **attn_metadata.asdict_zerocopy(),
    }
    broadcast_tensor_dict(metadata_dict, src=0)

    return (input_tokens, input_positions, attn_metadata,
            sampling_metadata, multi_modal_input)
def _prepare_decode(
    self,
    seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata]:
    """Build input tensors and attention metadata for a decode step.

    Every sequence contributes its last token, the position of that token,
    and the KV-cache slot for that position. With a sliding window, the
    recorded length is capped at the window and only the trailing blocks of
    the block table are kept.

    Returns:
        ``(input_tokens, input_positions, attn_metadata)``.
    """
    assert len(seq_group_metadata_list) > 0

    token_ids: List[int] = []
    positions: List[int] = []
    slots: List[int] = []
    seq_lens: List[int] = []
    tables: List[List[int]] = []

    for group in seq_group_metadata_list:
        assert not group.is_prompt
        assert group.token_chunk_size == 1

        for seq_id in list(group.seq_data.keys()):
            seq_data = group.seq_data[seq_id]
            token_ids.append(seq_data.get_last_token_id())

            full_len = seq_data.get_len()
            position = full_len - 1
            positions.append(position)

            if self.sliding_window is None:
                seq_lens.append(full_len)
            else:
                seq_lens.append(min(full_len, self.sliding_window))

            block_table = group.block_tables[seq_id]
            block_index, block_offset = divmod(position, self.block_size)
            slots.append(block_table[block_index] * self.block_size +
                         block_offset)

            if self.sliding_window is not None:
                window_blocks = self.sliding_window // self.block_size
                block_table = block_table[-window_blocks:]
            tables.append(block_table)

    max_decode_seq_len = max(seq_lens)

    input_tokens = torch.tensor(token_ids,
                                dtype=torch.long,
                                device=self.device)
    input_positions = torch.tensor(positions,
                                   dtype=torch.long,
                                   device=self.device)
    slot_mapping = torch.tensor(slots,
                                dtype=torch.long,
                                device=self.device)
    seq_lens_tensor = torch.tensor(seq_lens,
                                   dtype=torch.int,
                                   device=self.device)

    # Block tables are padded with 0 up to the longest one.
    widest_table = max(len(table) for table in tables)
    block_tables = make_tensor_with_pad(
        tables,
        max_len=widest_table,
        pad=0,
        dtype=torch.int,
        device=self.device,
    )

    attn_metadata = self.attn_backend.make_metadata(
        is_prompt=False,
        slot_mapping=slot_mapping,
        seq_lens=seq_lens,
        seqlen_q=None,
        max_seqlen=None,
        seq_lens_tensor=seq_lens_tensor,
        max_decode_seq_len=max_decode_seq_len,
        num_prefill_tokens=0,
        num_decode_tokens=len(input_tokens),
        num_prefills=0,
        block_tables=block_tables,
    )
    return input_tokens, input_positions, attn_metadata
@torch.inference_mode()
def execute_model(
    self,
    seq_group_metadata_list: List[SequenceGroupMetadata],
    kv_caches: List[torch.Tensor],
) -> Optional[SamplerOutput]:
    """Run a forward pass and, on the driver worker, sample the next tokens.

    Returns:
        The sampler output on the driver worker; None on every other
        worker.
    """
    (input_tokens, input_positions, attn_metadata, sampling_metadata,
     multi_modal_input) = self.prepare_input_tensors(seq_group_metadata_list)

    forward_kwargs = {
        "input_ids": input_tokens,
        "positions": input_positions,
        "kv_caches": kv_caches,
        "attn_metadata": attn_metadata,
    }
    # The image input is only passed when a vision-language config is set.
    if self.vision_language_config:
        forward_kwargs["image_input"] = multi_modal_input

    hidden_states = self.model(**forward_kwargs)

    # Compute the logits.
    logits = self.model.compute_logits(hidden_states, sampling_metadata)

    # Only perform sampling in the driver worker.
    if not self.is_driver_worker:
        return None

    return self.model.sample(
        logits=logits,
        sampling_metadata=sampling_metadata,
    )
def _prepare_prompt(
    self,
    seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
           Optional[torch.Tensor]]:
    """Build input tensors and attention metadata for a prompt batch.

    Each group must hold exactly one sequence. All prompt tokens are
    flattened into one token tensor; positions and slot-mapping entries
    cover the range from the computed-token count up to the prompt length.

    Returns:
        ``(input_tokens, input_positions, attn_metadata, seq_lens,
        multi_modal_input)``; ``multi_modal_input`` is None when no group
        carries multi-modal data.
    """
    assert len(seq_group_metadata_list) > 0

    token_ids: List[int] = []
    positions: List[int] = []
    slots: List[int] = []
    seq_lens: List[int] = []
    multi_modal_input_list: List[torch.Tensor] = []

    for group in seq_group_metadata_list:
        assert group.is_prompt
        seq_ids = list(group.seq_data.keys())
        assert len(seq_ids) == 1
        seq_id = seq_ids[0]

        seq_data = group.seq_data[seq_id]
        prompt_tokens = seq_data.get_token_ids()
        computed_len = seq_data.get_num_computed_tokens()
        seq_len = len(prompt_tokens)

        seq_lens.append(seq_len)
        token_ids.extend(prompt_tokens)
        # NOTE(woosuk): Here we assume that the first token in the prompt
        # is always the first token in the sequence.
        positions.extend(range(computed_len, seq_len))

        if group.multi_modal_data:
            multi_modal_input_list.append(group.multi_modal_data.data)

        if group.block_tables is None:
            # During memory profiling the block tables do not exist yet, so
            # every token gets the dummy slot.
            slots.extend([_PAD_SLOT_ID] * seq_len)
            continue

        block_table = group.block_tables[seq_id]

        # Tokens before first_kept fall outside the sliding window and are
        # mapped to _PAD_SLOT_ID. For example, with prompt length 10,
        # sliding window 8 and block size 4, the first two tokens are
        # masked and the slot mapping is [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
        if self.sliding_window is None:
            first_kept = 0
        else:
            first_kept = max(0, seq_len - self.sliding_window)

        for token_idx in range(computed_len, seq_len):
            if token_idx >= first_kept:
                block_number = block_table[token_idx // self.block_size]
                block_offset = token_idx % self.block_size
                slots.append(block_number * self.block_size + block_offset)
            else:
                slots.append(_PAD_SLOT_ID)

    if not multi_modal_input_list:
        multi_modal_input = None
    else:
        assert self.vision_language_config, (
            "Multi-modal inputs are only supported by "
            "vision language models.")
        multi_modal_input = torch.cat(multi_modal_input_list,
                                      dim=0).to(self.device)

    num_prompt_tokens = len(token_ids)

    input_tokens = torch.tensor(token_ids,
                                dtype=torch.long,
                                device=self.device)
    input_positions = torch.tensor(positions,
                                   dtype=torch.long,
                                   device=self.device)
    slot_mapping = torch.tensor(slots,
                                dtype=torch.long,
                                device=self.device)

    max_seqlen = max(seq_lens)
    # Running total of the prompt lengths with a leading 0.
    seqlen_q = torch.cumsum(torch.tensor([0, *seq_lens]),
                            dim=0).to(device=self.device)

    attn_metadata = self.attn_backend.make_metadata(
        is_prompt=True,
        slot_mapping=slot_mapping,
        seq_lens=seq_lens,
        seqlen_q=seqlen_q,
        max_seqlen=max_seqlen,
        seq_lens_tensor=None,
        max_decode_seq_len=None,
        num_prefills=len(seq_lens),
        num_prefill_tokens=num_prompt_tokens,
        num_decode_tokens=0,
        block_tables=torch.tensor([], device=self.device, dtype=torch.int),
    )
    return (input_tokens, input_positions, attn_metadata, seq_lens,
            multi_modal_input)
"""A XPU worker class."""
import gc
import os
from typing import List, Optional, Tuple
import intel_extension_for_pytorch # noqa: F401
import oneccl_bindings_for_pytorch # noqa: F401
import torch
import torch.distributed
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig,
SpeculativeConfig, VisionLanguageConfig)
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.utils import is_xpu
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.worker import Worker
from vllm.worker.worker_base import LoraNotSupportedWorkerBase
from vllm.worker.xpu_model_runner import XPUModelRunner
logger = init_logger(__name__)
class XPUWorker(LoraNotSupportedWorkerBase, Worker):
"""A worker class that executes (a partition of) the model on a GPU.
Each worker is associated with a single XPU device. The worker is
responsible for maintaining the KV cache and executing the model on the
XPU. In case of distributed inference, each worker is assigned a partition
of the model.
"""
def __init__(
    self,
    model_config: ModelConfig,
    parallel_config: ParallelConfig,
    scheduler_config: SchedulerConfig,
    device_config: DeviceConfig,
    cache_config: CacheConfig,
    load_config: LoadConfig,
    local_rank: int,
    rank: int,
    distributed_init_method: str,
    lora_config: Optional[LoRAConfig] = None,
    vision_language_config: Optional[VisionLanguageConfig] = None,
    speculative_config: Optional[SpeculativeConfig] = None,
    is_driver_worker: bool = False,
) -> None:
    """Store the configuration and create the XPU model runner.

    Assertions require an ``"xpu"`` device type, a working XPU runtime,
    rank 0 for the driver worker, and no LoRA config when a
    vision-language config is given. ``speculative_config`` is accepted
    but not used.
    """
    assert device_config.device_type == "xpu"
    assert is_xpu()

    self.model_config = model_config
    self.parallel_config = parallel_config
    self.scheduler_config = scheduler_config
    self.device_config = device_config
    self.cache_config = cache_config
    self.load_config = load_config
    self.local_rank = local_rank
    self.rank = rank
    self.distributed_init_method = distributed_init_method
    self.lora_config = lora_config
    self.is_driver_worker = is_driver_worker
    assert (not self.is_driver_worker
            ) or self.rank == 0, "The driver worker must have rank 0."

    self.vision_language_config = vision_language_config
    if self.vision_language_config:
        assert not self.lora_config, (
            "To be tested: vision language model with LoRA settings.")

    runner_kwargs = {
        "load_config": self.load_config,
        "lora_config": self.lora_config,
        "kv_cache_dtype": self.cache_config.cache_dtype,
        "is_driver_worker": is_driver_worker,
        "vision_language_config": vision_language_config,
    }
    self.model_runner = XPUModelRunner(  # type: ignore
        model_config,
        parallel_config,
        scheduler_config,
        device_config,
        cache_config,
        **runner_kwargs,
    )

    # Declared only; both are assigned once the cache is initialised.
    self.cache_engine: CacheEngine
    self.gpu_cache: List[torch.Tensor]
def init_device(self) -> None:
    """Bind this worker to its XPU and set up distributed state and seed.

    Raises:
        RuntimeError: if the configured device is not an XPU or XPU
            support is unavailable.
    """
    xpu_requested = self.device_config.device.type == "xpu"
    if not (xpu_requested and is_xpu()):
        raise RuntimeError(
            f"Not support device type: {self.device_config.device}")

    self.device = torch.device(f"xpu:{self.local_rank}")
    torch.xpu.set_device(self.device)
    torch.xpu.empty_cache()
    # Total device memory, recorded for the later block-count computation.
    device_props = torch.xpu.get_device_properties(self.local_rank)
    self.init_gpu_memory = device_props.total_memory

    # Initialize the distributed environment.
    self.init_worker_distributed_environment()
    # Seed the random number generators from the model config.
    set_random_seed(self.model_config.seed)
# keep this method for `empty_cache` and `synchronize` api
@torch.inference_mode()
def determine_num_available_blocks(self) -> Tuple[int, int]:
    """Work out how many KV-cache blocks fit on the device and in swap.

    A profiling forward pass is run first. The memory allocated afterwards
    is subtracted from the ``gpu_memory_utilization`` share of total device
    memory, and the remainder is divided by the cache block size. The CPU
    block count is the swap space divided by the same block size. Both
    counts are clamped at zero.

    Returns:
        ``(num_gpu_blocks, num_cpu_blocks)``.
    """
    torch.xpu.empty_cache()

    # Forward pass on dummy inputs to exercise the model's memory usage.
    self.model_runner.profile_run()
    torch.xpu.synchronize()

    allocated = torch.xpu.memory_allocated()
    device_props = torch.xpu.get_device_properties(self.local_rank)
    total_gpu_memory = device_props.total_memory
    free_gpu_memory = total_gpu_memory - allocated

    # NOTE(woosuk): Here we assume that the other processes using the same
    # GPU did not change their memory usage during the profiling.
    peak_memory = self.init_gpu_memory - free_gpu_memory
    assert peak_memory > 0, (
        "Error in memory profiling. This happens when the GPU memory was "
        "not properly cleaned up before initializing the vLLM instance.")

    cache_block_size = self.get_cache_block_size_bytes()
    gpu_budget = (total_gpu_memory *
                  self.cache_config.gpu_memory_utilization - peak_memory)
    num_gpu_blocks = max(int(gpu_budget // cache_block_size), 0)
    num_cpu_blocks = max(
        int(self.cache_config.swap_space_bytes // cache_block_size), 0)

    # Release whatever the profiling pass left behind.
    gc.collect()
    torch.xpu.empty_cache()
    return num_gpu_blocks, num_cpu_blocks
def _warm_up_model(self) -> None:
# IPEX don't support capture graph yet
pass
def init_worker_distributed_environment(self) -> None:
    """Initialize the distributed environment.

    If ``torch.distributed`` is already up, only its world size is checked.
    Otherwise it is initialised with the "ccl" backend. Model parallelism
    is ensured in both cases.

    Raises:
        RuntimeError: if an existing process group has a different world
            size than the parallel config.
        ValueError: if nothing is initialised and no init method is given.
    """
    parallel_config = self.parallel_config
    expected_world_size = parallel_config.world_size
    init_method = self.distributed_init_method

    already_initialized = torch.distributed.is_initialized()
    if not already_initialized and not init_method:
        raise ValueError(
            "distributed_init_method must be set if torch.distributed "
            "is not already initialized")

    if already_initialized:
        torch_world_size = torch.distributed.get_world_size()
        if torch_world_size != expected_world_size:
            raise RuntimeError(
                "torch.distributed is already initialized but the torch "
                "world size does not match parallel_config.world_size "
                f"({torch_world_size} vs. {parallel_config.world_size}).")
    else:
        # Default the Level Zero IPC exchange to sockets; oneCCL's own
        # default (`drmfd`) needs libdrm and the drm headers installed.
        # A value already present in the environment is kept.
        os.environ['CCL_ZE_IPC_EXCHANGE'] = os.getenv(
            "CCL_ZE_IPC_EXCHANGE", "sockets")
        init_distributed_environment(
            world_size=expected_world_size,
            rank=self.rank,
            distributed_init_method=init_method,
            local_rank=self.local_rank,
            backend="ccl")

    ensure_model_parallel_initialized(
        parallel_config.tensor_parallel_size,
        parallel_config.pipeline_parallel_size)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment