Unverified commit 4d6ada94, authored by Swapnil Parekh, committed by GitHub
Browse files

[CORE] Adding support for insertion of soft-tuned prompts (#4645)


Co-authored-by: Swapnil Parekh <swapnilp@ibm.com>
Co-authored-by: Joe G <joseph.granados@h2o.ai>
Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>
parent a0550cbc
......@@ -4,7 +4,7 @@ import torch
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig,
SchedulerConfig)
PromptAdapterConfig, SchedulerConfig)
from vllm.logger import init_logger
from vllm.sequence import (IntermediateTensors, SamplerOutput,
SequenceGroupMetadata)
......@@ -48,6 +48,7 @@ class TP1DraftModelRunner(ModelRunner):
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
multimodal_config: Optional[MultiModalConfig] = None,
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
return_hidden_states: bool = False,
):
if return_hidden_states:
......@@ -66,6 +67,7 @@ class TP1DraftModelRunner(ModelRunner):
kv_cache_dtype=kv_cache_dtype,
is_driver_worker=is_driver_worker,
multimodal_config=multimodal_config,
prompt_adapter_config=prompt_adapter_config,
return_hidden_states=return_hidden_states,
)
......@@ -136,6 +138,13 @@ class TP1DraftModelRunner(ModelRunner):
self.set_active_loras(model_input.lora_requests,
model_input.lora_mapping)
if self.prompt_adapter_config:
assert model_input.prompt_adapter_requests is not None
assert model_input.prompt_adapter_mapping is not None
self.set_active_prompt_adapters(
model_input.prompt_adapter_requests,
model_input.prompt_adapter_mapping)
virtual_engine = model_input.virtual_engine
outputs: List[SamplerOutput] = []
for step in range(num_steps):
......
......@@ -8,7 +8,7 @@ from torch import nn
from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig,
SchedulerConfig)
PromptAdapterConfig, SchedulerConfig)
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader import get_model
......@@ -81,6 +81,7 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
kv_cache_dtype: Optional[str] = "auto",
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
is_driver_worker: bool = False,
*args,
**kwargs,
......@@ -94,6 +95,7 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
self.cache_config = cache_config
self.lora_config = lora_config
self.multimodal_config = multimodal_config
self.prompt_adapter_config = prompt_adapter_config
self.load_config = load_config
self.is_driver_worker = is_driver_worker
......
......@@ -7,7 +7,7 @@ import torch.distributed
from vllm.attention import get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig,
SchedulerConfig)
PromptAdapterConfig, SchedulerConfig)
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.logger import init_logger
......@@ -133,6 +133,7 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
lora_config: Optional[LoRAConfig] = None,
multimodal_config: Optional[MultiModalConfig] = None,
kv_cache_dtype: Optional[str] = "auto",
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
is_driver_worker: bool = False,
) -> None:
self.model_config = model_config
......@@ -145,6 +146,7 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
self.rank = rank
self.distributed_init_method = distributed_init_method
self.lora_config = lora_config
self.prompt_adapter_config = prompt_adapter_config
self.multimodal_config = multimodal_config
self.is_driver_worker = is_driver_worker
if self.is_driver_worker:
......@@ -167,6 +169,7 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
lora_config=self.lora_config,
multimodal_config=self.multimodal_config,
kv_cache_dtype=kv_cache_dtype,
prompt_adapter_config=self.prompt_adapter_config,
is_driver_worker=is_driver_worker)
# Uninitialized cache engine. Will be initialized by
# initialize_cache.
......
......@@ -5,7 +5,7 @@ import torch
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig,
SchedulerConfig)
PromptAdapterConfig, SchedulerConfig)
from vllm.logger import init_logger
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.pooling_params import PoolingParams
......@@ -40,6 +40,7 @@ class EmbeddingModelRunner(
lora_config: Optional[LoRAConfig],
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
multimodal_config: Optional[MultiModalConfig] = None,
):
super().__init__(model_config,
......@@ -51,6 +52,7 @@ class EmbeddingModelRunner(
lora_config=lora_config,
kv_cache_dtype=kv_cache_dtype,
is_driver_worker=is_driver_worker,
prompt_adapter_config=prompt_adapter_config,
multimodal_config=multimodal_config)
@torch.inference_mode()
......@@ -71,6 +73,13 @@ class EmbeddingModelRunner(
self.set_active_loras(model_input.lora_requests,
model_input.lora_mapping)
if self.prompt_adapter_config:
assert model_input.prompt_adapter_requests is not None
assert model_input.prompt_adapter_mapping is not None
self.set_active_prompt_adapters(
model_input.prompt_adapter_requests,
model_input.prompt_adapter_mapping)
# Currently cuda graph is only supported by the decode phase.
assert model_input.attn_metadata is not None
prefill_meta = model_input.attn_metadata.prefill_metadata
......
......@@ -25,7 +25,7 @@ except ImportError:
from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig,
SchedulerConfig)
PromptAdapterConfig, SchedulerConfig)
from vllm.distributed import get_pp_group
from vllm.distributed.parallel_state import graph_capture
from vllm.inputs import INPUT_REGISTRY
......@@ -40,6 +40,10 @@ from vllm.model_executor.models.interfaces import (supports_lora,
supports_vision)
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
MultiModalInputs)
from vllm.prompt_adapter.layers import PromptAdapterMapping
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.prompt_adapter.worker_manager import (
LRUCacheWorkerPromptAdapterManager)
from vllm.sampling_params import SamplingParams
from vllm.sequence import (IntermediateTensors, SamplerOutput,
SequenceGroupMetadata)
......@@ -85,6 +89,8 @@ class ModelInputForGPU(ModelRunnerInputBase):
lora_mapping: Optional["LoRAMapping"] = None
lora_requests: Optional[Set[LoRARequest]] = None
attn_metadata: Optional["AttentionMetadata"] = None
prompt_adapter_mapping: Optional[PromptAdapterMapping] = None
prompt_adapter_requests: Optional[Set[PromptAdapterRequest]] = None
multi_modal_kwargs: Optional[Mapping[str, BatchedTensors]] = None
request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None
finished_requests_ids: Optional[List[str]] = None
......@@ -97,6 +103,8 @@ class ModelInputForGPU(ModelRunnerInputBase):
"lora_requests": self.lora_requests,
"lora_mapping": self.lora_mapping,
"multi_modal_kwargs": self.multi_modal_kwargs,
"prompt_adapter_mapping": self.prompt_adapter_mapping,
"prompt_adapter_requests": self.prompt_adapter_requests,
"virtual_engine": self.virtual_engine,
"request_ids_to_seq_ids": self.request_ids_to_seq_ids,
"finished_requests_ids": self.finished_requests_ids,
......@@ -133,6 +141,8 @@ class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
"lora_requests": self.lora_requests,
"lora_mapping": self.lora_mapping,
"multi_modal_kwargs": self.multi_modal_kwargs,
"prompt_adapter_mapping": self.prompt_adapter_mapping,
"prompt_adapter_requests": self.prompt_adapter_requests,
"virtual_engine": self.virtual_engine,
"request_ids_to_seq_ids": self.request_ids_to_seq_ids,
"finished_requests_ids": self.finished_requests_ids,
......@@ -172,6 +182,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
lora_config: Optional[LoRAConfig],
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
multimodal_config: Optional[MultiModalConfig] = None,
return_hidden_states: bool = False,
):
......@@ -183,6 +194,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self.lora_config = lora_config
self.load_config = load_config
self.is_driver_worker = is_driver_worker
self.prompt_adapter_config = prompt_adapter_config
self.multimodal_config = multimodal_config
self.return_hidden_states = return_hidden_states
......@@ -232,6 +244,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self.model: nn.Module # Set after load_model
# Set after load_model.
self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None
self.prompt_adapter_manager: LRUCacheWorkerPromptAdapterManager = None
self.flashinfer_decode_workspace_buffer = None
self.flashinfer_decode_wrapper = None
......@@ -240,16 +253,14 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
def load_model(self) -> None:
with CudaMemoryProfiler() as m:
self.model = get_model(
model_config=self.model_config,
device_config=self.device_config,
load_config=self.load_config,
lora_config=self.lora_config,
multimodal_config=self.multimodal_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
cache_config=self.cache_config,
)
self.model = get_model(model_config=self.model_config,
device_config=self.device_config,
load_config=self.load_config,
lora_config=self.lora_config,
multimodal_config=self.multimodal_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
cache_config=self.cache_config)
self.model_memory_usage = m.consumed_memory
logger.info("Loading model weights took %.4f GB",
......@@ -274,6 +285,15 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
)
self.model = self.lora_manager.create_lora_manager(self.model)
if self.prompt_adapter_config:
self.prompt_adapter_manager = LRUCacheWorkerPromptAdapterManager(
self.scheduler_config.max_num_seqs,
self.scheduler_config.max_num_batched_tokens, self.device,
self.prompt_adapter_config)
self.model = (
self.prompt_adapter_manager.create_prompt_adapter_manager(
self.model))
if self.kv_cache_dtype == "fp8" and is_hip():
# Currently only ROCm accepts kv-cache scaling factors
# via quantization_param_path and this will be deprecated
......@@ -354,6 +374,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
lora_index_mapping: List[int] = []
lora_prompt_mapping: List[int] = []
lora_requests: Set[LoRARequest] = set()
prompt_adapter_index_mapping: List[int] = []
prompt_adapter_prompt_mapping: List[int] = []
prompt_adapter_requests: Set[PromptAdapterRequest] = set()
seq_lens: List[int] = []
prefill_seq_lens: List[int] = []
......@@ -504,6 +527,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
input_tokens.extend(tokens)
input_positions.extend(list(range(context_len, seq_len)))
lora_id = seq_group_metadata.lora_int_id
prompt_adapter_id = seq_group_metadata.prompt_adapter_id
if is_prompt:
assert len(seq_ids) == 1
......@@ -534,6 +558,21 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
mm_kwargs = self.multi_modal_input_mapper(mm_data)
multi_modal_inputs_list.append(mm_kwargs)
if prompt_adapter_id > 0 and is_prompt:
prompt_adapter_requests.add(
seq_group_metadata.prompt_adapter_request)
num_tokens = seq_group_metadata.\
prompt_adapter_num_virtual_tokens
pm = [prompt_adapter_id
] * num_tokens + [0] * (query_len - num_tokens)
prompt_adapter_index_mapping += pm
prompt_adapter_prompt_mapping.extend(
[prompt_adapter_id] *
(query_len if seq_group_metadata.sampling_params
and seq_group_metadata.sampling_params.prompt_logprobs
else 1))
is_profile_run = _is_block_tables_empty(
seq_group_metadata.block_tables)
if is_profile_run:
......@@ -618,12 +657,11 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
seq_lens.append(1)
block_tables.append([])
lora_index_mapping.append(0)
prompt_adapter_index_mapping.append(0)
if self.attn_backend.get_name() == "flashinfer":
last_paged_kv_indptr = paged_kv_indptr[-1]
paged_kv_indptr.append(last_paged_kv_indptr)
paged_kv_last_page_len.append(0)
batch_size = graph_batch_size
num_decode_tokens = batch_size
......@@ -759,6 +797,14 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
else:
lora_mapping = None
if self.prompt_adapter_config:
prompt_adapter_mapping = PromptAdapterMapping(
prompt_adapter_index_mapping,
prompt_adapter_prompt_mapping,
)
else:
prompt_adapter_mapping = None
multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
device=self.device)
request_ids_to_seq_ids = {
......@@ -776,7 +822,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
lora_requests=lora_requests,
multi_modal_kwargs=multi_modal_kwargs,
request_ids_to_seq_ids=request_ids_to_seq_ids,
finished_requests_ids=finished_requests_ids)
finished_requests_ids=finished_requests_ids,
prompt_adapter_mapping=prompt_adapter_mapping,
prompt_adapter_requests=prompt_adapter_requests,
)
@torch.inference_mode()
def profile_run(self) -> None:
......@@ -878,33 +927,67 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
def remove_all_loras(self):
    """Remove every adapter held by the LoRA manager.

    Raises:
        RuntimeError: if ``self.lora_manager`` is falsy (LoRA not enabled).
    """
    if not self.lora_manager:
        raise RuntimeError("LoRA is not enabled.")
    # Only the adapter-style manager method is invoked; the pre-rename
    # ``remove_all_loras()`` call that sat beside it is dropped so the
    # manager is not asked to clear twice through two different names.
    self.lora_manager.remove_all_adapters()
def set_active_loras(self, lora_requests: Set[LoRARequest],
                     lora_mapping: LoRAMapping) -> None:
    """Forward the LoRA requests and mapping to the LoRA manager.

    Args:
        lora_requests: set of LoRA requests to activate.
        lora_mapping: mapping passed through unchanged to the manager.

    Raises:
        RuntimeError: if ``self.lora_manager`` is falsy (LoRA not enabled).
    """
    if not self.lora_manager:
        raise RuntimeError("LoRA is not enabled.")
    # Single call through the adapter-style name; the duplicate pre-rename
    # ``set_active_loras`` call is removed so activation happens once.
    self.lora_manager.set_active_adapters(lora_requests, lora_mapping)
def add_lora(self, lora_request: LoRARequest) -> bool:
    """Register a LoRA request with the LoRA manager.

    Returns:
        The value returned by the manager's ``add_adapter``.

    Raises:
        RuntimeError: if ``self.lora_manager`` is falsy (LoRA not enabled).
    """
    if not self.lora_manager:
        raise RuntimeError("LoRA is not enabled.")
    # The stale ``return ...add_lora(...)`` made this adapter-style call
    # unreachable; it is now the only return.
    return self.lora_manager.add_adapter(lora_request)
def remove_lora(self, lora_id: int) -> bool:
    """Remove the LoRA with the given id through the LoRA manager.

    Returns:
        The value returned by the manager's ``remove_adapter``.

    Raises:
        RuntimeError: if ``self.lora_manager`` is falsy (LoRA not enabled).
    """
    if not self.lora_manager:
        raise RuntimeError("LoRA is not enabled.")
    # The stale ``return ...remove_lora(...)`` made this adapter-style call
    # unreachable; it is now the only return.
    return self.lora_manager.remove_adapter(lora_id)
def pin_lora(self, lora_id: int) -> bool:
    """Pin the LoRA with the given id through the LoRA manager.

    Returns:
        The value returned by the manager's ``pin_adapter``.

    Raises:
        RuntimeError: if ``self.lora_manager`` is falsy (LoRA not enabled).
    """
    if not self.lora_manager:
        raise RuntimeError("LoRA is not enabled.")
    # The stale ``return ...pin_lora(...)`` made this adapter-style call
    # unreachable; it is now the only return.
    return self.lora_manager.pin_adapter(lora_id)
def list_loras(self) -> Set[int]:
    """Return the adapter ids reported by the LoRA manager.

    Returns:
        The value returned by the manager's ``list_adapters``.

    Raises:
        RuntimeError: if ``self.lora_manager`` is falsy (LoRA not enabled).
    """
    if not self.lora_manager:
        raise RuntimeError("LoRA is not enabled.")
    # The stale ``return ...list_loras()`` made this adapter-style call
    # unreachable; it is now the only return.
    return self.lora_manager.list_adapters()
def remove_all_prompt_adapters(self):
    """Ask the prompt-adapter manager to drop all of its adapters.

    Raises:
        RuntimeError: if ``self.prompt_adapter_manager`` is falsy.
    """
    manager = self.prompt_adapter_manager
    if manager:
        manager.remove_all_adapters()
        return
    raise RuntimeError("PromptAdapter is not enabled.")
def set_active_prompt_adapters(
        self, prompt_adapter_requests: Set[PromptAdapterRequest],
        prompt_adapter_mapping: PromptAdapterMapping) -> None:
    """Hand the prompt-adapter requests and mapping to the manager.

    Args:
        prompt_adapter_requests: set of prompt-adapter requests to activate.
        prompt_adapter_mapping: mapping passed through unchanged.

    Raises:
        RuntimeError: if ``self.prompt_adapter_manager`` is falsy.
    """
    manager = self.prompt_adapter_manager
    if manager:
        manager.set_active_adapters(prompt_adapter_requests,
                                    prompt_adapter_mapping)
        return
    raise RuntimeError("PromptAdapter is not enabled.")
def add_prompt_adapter(
        self, prompt_adapter_request: PromptAdapterRequest) -> bool:
    """Register a prompt-adapter request with the manager.

    Returns:
        The value returned by the manager's ``add_adapter``.

    Raises:
        RuntimeError: if ``self.prompt_adapter_manager`` is falsy.
    """
    manager = self.prompt_adapter_manager
    if manager:
        return manager.add_adapter(prompt_adapter_request)
    raise RuntimeError("PromptAdapter is not enabled.")
def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
    """Remove the prompt adapter with the given id via the manager.

    Returns:
        The value returned by the manager's ``remove_adapter``.

    Raises:
        RuntimeError: if ``self.prompt_adapter_manager`` is falsy.
    """
    manager = self.prompt_adapter_manager
    if manager:
        return manager.remove_adapter(prompt_adapter_id)
    raise RuntimeError("PromptAdapter is not enabled.")
def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
    """Pin the prompt adapter with the given id via the manager.

    Returns:
        The value returned by the manager's ``pin_adapter``.

    Raises:
        RuntimeError: if ``self.prompt_adapter_manager`` is falsy.
    """
    manager = self.prompt_adapter_manager
    if manager:
        return manager.pin_adapter(prompt_adapter_id)
    raise RuntimeError("PromptAdapter is not enabled.")
def list_prompt_adapters(self) -> Set[int]:
    """Return the adapter ids reported by the prompt-adapter manager.

    Returns:
        The value returned by the manager's ``list_adapters``.

    Raises:
        RuntimeError: if ``self.prompt_adapter_manager`` is falsy.
    """
    manager = self.prompt_adapter_manager
    if manager:
        return manager.list_adapters()
    raise RuntimeError("PromptAdapter is not enabled.")
@torch.inference_mode()
def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
......@@ -1063,6 +1146,14 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
)
self.set_active_loras(set(), lora_mapping)
if self.prompt_adapter_config:
prompt_adapter_mapping = PromptAdapterMapping(
[-1] * batch_size,
[-1] * batch_size,
)
self.set_active_prompt_adapters(
set(), prompt_adapter_mapping)
graph_runner = CUDAGraphRunner(
self.model, self.attn_backend.get_name())
......@@ -1189,6 +1280,13 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
self.set_active_loras(model_input.lora_requests,
model_input.lora_mapping)
if self.prompt_adapter_config:
assert model_input.prompt_adapter_requests is not None
assert model_input.prompt_adapter_mapping is not None
self.set_active_prompt_adapters(
model_input.prompt_adapter_requests,
model_input.prompt_adapter_mapping)
if self.attn_backend.get_name() == "flashinfer":
assert model_input.attn_metadata is not None
assert model_input.input_tokens is not None
......
......@@ -8,7 +8,8 @@ import torch.distributed
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig,
SchedulerConfig, SpeculativeConfig)
PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig)
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment,
set_custom_all_reduce)
......@@ -16,6 +17,7 @@ from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.platforms import current_platform
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import ExecuteModelRequest
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.embedding_model_runner import EmbeddingModelRunner
......@@ -45,6 +47,7 @@ class Worker(LocalOrDistributedWorkerBase):
lora_config: Optional[LoRAConfig] = None,
multimodal_config: Optional[MultiModalConfig] = None,
speculative_config: Optional[SpeculativeConfig] = None,
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
is_driver_worker: bool = False,
model_runner_cls: Optional[Type[GPUModelRunnerBase]] = None,
) -> None:
......@@ -59,6 +62,7 @@ class Worker(LocalOrDistributedWorkerBase):
self.distributed_init_method = distributed_init_method
self.lora_config = lora_config
self.load_config = load_config
self.prompt_adapter_config = prompt_adapter_config
self.is_driver_worker = is_driver_worker
if parallel_config and is_driver_worker:
assert rank % parallel_config.tensor_parallel_size == 0, \
......@@ -92,6 +96,7 @@ class Worker(LocalOrDistributedWorkerBase):
lora_config=self.lora_config,
kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=is_driver_worker,
prompt_adapter_config=prompt_adapter_config,
multimodal_config=multimodal_config,
**speculative_args,
)
......@@ -296,6 +301,19 @@ class Worker(LocalOrDistributedWorkerBase):
def list_loras(self) -> Set[int]:
    """Return whatever the model runner's ``list_loras`` reports."""
    runner = self.model_runner
    return runner.list_loras()
def add_prompt_adapter(
        self, prompt_adapter_request: PromptAdapterRequest) -> bool:
    """Delegate adding a prompt adapter to the model runner.

    Returns:
        The value returned by the runner's ``add_prompt_adapter``.
    """
    runner = self.model_runner
    return runner.add_prompt_adapter(prompt_adapter_request)
def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
    """Delegate removal of a prompt adapter to the model runner.

    Args:
        prompt_adapter_id: id of the prompt adapter to remove.

    Returns:
        The value returned by the runner's ``remove_prompt_adapter``.
    """
    # Previously this called ``model_runner.remove_lora``, which targets the
    # LoRA manager: it would remove a LoRA sharing the id (or raise "LoRA is
    # not enabled.") and never touch the prompt adapter. Route to the
    # prompt-adapter method, matching the sibling add/pin/list wrappers.
    return self.model_runner.remove_prompt_adapter(prompt_adapter_id)
def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
    """Delegate pinning of a prompt adapter to the model runner.

    Returns:
        The value returned by the runner's ``pin_prompt_adapter``.
    """
    runner = self.model_runner
    return runner.pin_prompt_adapter(prompt_adapter_id)
def list_prompt_adapters(self) -> Set[int]:
    """Return whatever the model runner's ``list_prompt_adapters`` reports."""
    runner = self.model_runner
    return runner.list_prompt_adapters()
@property
def max_model_len(self) -> int:
    """Expose ``max_model_len`` from the worker's model config."""
    config = self.model_config
    return config.max_model_len
......
......@@ -8,7 +8,7 @@ import torch.nn as nn
from vllm.attention import get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig,
SchedulerConfig)
PromptAdapterConfig, SchedulerConfig)
from vllm.distributed import broadcast_tensor_dict
from vllm.inputs import INPUT_REGISTRY
from vllm.logger import init_logger
......@@ -88,6 +88,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
kv_cache_dtype: Optional[str] = "auto",
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
is_driver_worker: bool = False,
*args,
**kwargs,
......@@ -98,6 +99,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
self.lora_config = lora_config
self.load_config = load_config
self.cache_config = cache_config
self.prompt_adapter_config = prompt_adapter_config
self.multimodal_config = multimodal_config
self.is_driver_worker = is_driver_worker
......
......@@ -10,7 +10,8 @@ import torch.distributed
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig,
SchedulerConfig, SpeculativeConfig)
PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig)
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.logger import init_logger
......@@ -47,6 +48,7 @@ class XPUWorker(LoraNotSupportedWorkerBase, Worker):
lora_config: Optional[LoRAConfig] = None,
multimodal_config: Optional[MultiModalConfig] = None,
speculative_config: Optional[SpeculativeConfig] = None,
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
is_driver_worker: bool = False,
) -> None:
assert device_config.device_type == "xpu"
......@@ -63,6 +65,7 @@ class XPUWorker(LoraNotSupportedWorkerBase, Worker):
self.rank = rank
self.distributed_init_method = distributed_init_method
self.lora_config = lora_config
self.prompt_adapter_config = prompt_adapter_config
self.is_driver_worker = is_driver_worker
if self.is_driver_worker:
assert self.rank == 0, "The driver worker must have rank 0."
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment