Unverified Commit 9d6a8daa authored by Mor Zusman's avatar Mor Zusman Committed by GitHub
Browse files
parent ee93f4f9
......@@ -23,4 +23,4 @@ docker exec cpu-test-avx2 bash -c "python3 examples/offline_inference.py"
docker exec cpu-test bash -c "cd tests;
pip install pytest Pillow protobuf
cd ../
pytest -v -s tests/models -m \"not vlm\" --ignore=tests/models/test_embedding.py --ignore=tests/models/test_registry.py"
pytest -v -s tests/models -m \"not vlm\" --ignore=tests/models/test_embedding.py --ignore=tests/models/test_registry.py --ignore=tests/models/test_jamba.py" # Mamba on CPU is not supported
......@@ -43,6 +43,10 @@ COPY requirements-cuda.txt requirements-cuda.txt
RUN --mount=type=cache,target=/root/.cache/pip \
python3 -m pip install -r requirements-cuda.txt
COPY requirements-mamba.txt requirements-mamba.txt
RUN python3 -m pip install packaging
RUN python3 -m pip install -r requirements-mamba.txt
# cuda arch list used by torch
# can be useful for both `dev` and `test`
# explicitly set the list to avoid issues with torch 2.2
......@@ -123,6 +127,21 @@ RUN --mount=type=cache,target=/root/.cache/pip \
python3 -m pip install -r requirements-dev.txt
#################### DEV IMAGE ####################
#################### MAMBA Build IMAGE ####################
FROM dev as mamba-builder
# max jobs used for build
ARG max_jobs=2
ENV MAX_JOBS=${max_jobs}
WORKDIR /usr/src/mamba
COPY requirements-mamba.txt requirements-mamba.txt
# Download the wheel or build it if a pre-compiled release doesn't exist
RUN pip --verbose wheel -r requirements-mamba.txt \
--no-build-isolation --no-deps --no-cache-dir
#################### MAMBA Build IMAGE ####################
#################### vLLM installation IMAGE ####################
# image with vLLM installed
......@@ -143,6 +162,10 @@ RUN ldconfig /usr/local/cuda-$(echo $CUDA_VERSION | cut -d. -f1,2)/compat/
RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist \
--mount=type=cache,target=/root/.cache/pip \
python3 -m pip install dist/*.whl --verbose
RUN --mount=type=bind,from=mamba-builder,src=/usr/src/mamba,target=/usr/src/mamba \
--mount=type=cache,target=/root/.cache/pip \
python3 -m pip install /usr/src/mamba/*.whl --no-cache-dir
#################### vLLM installation IMAGE ####################
......
......@@ -87,6 +87,10 @@ Alongside each architecture, we include some popular models that use it.
- Jais
- :code:`core42/jais-13b`, :code:`core42/jais-13b-chat`, :code:`core42/jais-30b-v3`, :code:`core42/jais-30b-chat-v3`, etc.
-
* - :code:`JambaForCausalLM`
- Jamba
- :code:`ai21labs/Jamba-v0.1`, etc.
- ✅︎
* - :code:`LlamaForCausalLM`
- LLaMA, Llama 2, Meta Llama 3, Vicuna, Alpaca, Yi
- :code:`meta-llama/Meta-Llama-3-8B-Instruct`, :code:`meta-llama/Meta-Llama-3-70B-Instruct`, :code:`meta-llama/Llama-2-13b-hf`, :code:`meta-llama/Llama-2-70b-hf`, :code:`openlm-research/open_llama_13b`, :code:`lmsys/vicuna-13b-v1.3`, :code:`01-ai/Yi-6B`, :code:`01-ai/Yi-34B`, etc.
......
# Mamba dependencies
mamba-ssm>=1.2.2
causal-conv1d>=1.2.0
import pytest
MODELS = ["ai21labs/Jamba-tiny-random"]
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [20])
def test_models(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
) -> None:
# To pass the small model tests, we need full precision.
assert dtype == "float"
with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
with vllm_runner(model, dtype=dtype) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
for i in range(len(example_prompts)):
hf_output_ids, hf_output_str = hf_outputs[i]
vllm_output_ids, vllm_output_str = vllm_outputs[i]
assert hf_output_str == vllm_output_str, (
f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
assert hf_output_ids == vllm_output_ids, (
f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
def test_state_cleanup(
vllm_runner,
model: str,
dtype: str,
example_prompts,
) -> None:
# This test is for verifying that the Jamba state is cleaned up between
# steps, If its not cleaned, an error would be expected.
try:
with vllm_runner(model, dtype=dtype) as vllm_model:
for _ in range(10):
vllm_model.generate_greedy([example_prompts[0]] * 100, 1)
except ValueError:
pytest.fail("Jamba inner state wasn't cleaned up between states, "
"could be related to finished_requests_ids")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
def test_model_print(
vllm_runner,
model: str,
dtype: str,
) -> None:
with vllm_runner(model, dtype=dtype) as vllm_model:
# This test is for verifying whether the model's extra_repr
# can be printed correctly.
print(vllm_model.model.llm_engine.model_executor.driver_worker.
model_runner.model)
......@@ -386,9 +386,36 @@ class ModelConfig:
return num_heads // parallel_config.tensor_parallel_size
def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
total_num_hidden_layers = self.hf_text_config.num_hidden_layers
total_num_hidden_layers = getattr(self.hf_text_config,
"num_hidden_layers", 0)
return total_num_hidden_layers // parallel_config.pipeline_parallel_size
def contains_seqlen_agnostic_layers(
self, parallel_config: "ParallelConfig") -> bool:
"""True for Mamba/SSM models (Jamba)"""
return self._get_num_seqlen_agnostic_layers(parallel_config) > 0
def get_layers_block_type(self,
parallel_config: "ParallelConfig") -> List[str]:
num_layers = self.get_num_layers(parallel_config)
# Transformers supports layers_block_type @property
return getattr(self.hf_config, "layers_block_type",
["attention"] * num_layers)
def get_num_attention_layers(self,
parallel_config: "ParallelConfig") -> int:
return len([
t for t in self.get_layers_block_type(parallel_config)
if t == "attention"
])
def _get_num_seqlen_agnostic_layers(
self, parallel_config: "ParallelConfig") -> int:
return len([
t for t in self.get_layers_block_type(parallel_config)
if t != "attention"
])
class CacheConfig:
"""Configuration for the KV cache.
......
......@@ -299,7 +299,10 @@ class Scheduler:
# Sequence groups in the SWAPPED state.
# Contain decode requests that are swapped out.
self.swapped: Deque[SequenceGroup] = deque()
# Sequence groups finished requests ids since last step iteration.
# It lets the model know that any state associated with these requests
# can and must be released after the current step.
self._finished_requests_ids: List[str] = list()
# Time at previous scheduling step
self.prev_time = 0.0
# Did we schedule a prompt at previous step?
......@@ -373,6 +376,12 @@ class Scheduler:
def get_num_unfinished_seq_groups(self) -> int:
return len(self.waiting) + len(self.running) + len(self.swapped)
def get_and_reset_finished_requests_ids(self) -> List[str]:
"""Flushes the list of request ids of previously finished seq_groups."""
finished_requests_ids = self._finished_requests_ids
self._finished_requests_ids = list()
return finished_requests_ids
def _schedule_running(
self,
running_queue: deque,
......@@ -1036,6 +1045,11 @@ class Scheduler:
self.block_manager.free(seq)
def free_finished_seq_groups(self) -> None:
for queue in [self.running, self.swapped, self.waiting]:
self._finished_requests_ids += [
seq_group.request_id for seq_group in queue
if seq_group.is_finished()
]
self.running = deque(seq_group for seq_group in self.running
if not seq_group.is_finished())
......
......@@ -224,6 +224,8 @@ class _AsyncLLMEngine(LLMEngine):
"""
seq_group_metadata_list, scheduler_outputs = self.scheduler[
virtual_engine].schedule()
finished_requests_ids = self.scheduler[
virtual_engine].get_and_reset_finished_requests_ids()
if not scheduler_outputs.is_empty():
# Execute the model.
......@@ -235,7 +237,7 @@ class _AsyncLLMEngine(LLMEngine):
virtual_engine=virtual_engine,
num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
running_queue_size=scheduler_outputs.running_queue_size,
)
finished_requests_ids=finished_requests_ids)
output = await self.model_executor.execute_model_async(
execute_model_req)
else:
......
......@@ -846,6 +846,8 @@ class LLMEngine:
"as performance will be severely degraded otherwise.")
seq_group_metadata_list, scheduler_outputs = self.scheduler[
0].schedule()
finished_requests_ids = self.scheduler[
0].get_and_reset_finished_requests_ids()
if not scheduler_outputs.is_empty():
execute_model_req = ExecuteModelRequest(
......@@ -855,7 +857,7 @@ class LLMEngine:
blocks_to_copy=scheduler_outputs.blocks_to_copy,
num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
running_queue_size=scheduler_outputs.running_queue_size,
)
finished_requests_ids=finished_requests_ids)
output = self.model_executor.execute_model(
execute_model_req=execute_model_req)
else:
......
......@@ -63,6 +63,7 @@ _GENERATION_MODELS = {
"XverseForCausalLM": ("xverse", "XverseForCausalLM"),
"Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
"MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
"JambaForCausalLM": ("jamba", "JambaForCausalLM")
}
_EMBEDDING_MODELS = {
......
This diff is collapsed.
......@@ -934,6 +934,8 @@ class ExecuteModelRequest:
previous_hidden_states: Optional[HiddenStates] = None
# The number of forward steps to run.
num_steps: int = 1
# Finished request ids since last step.
finished_requests_ids: List[str] = field(default_factory=list)
def clone(
self, seq_group_metadata_list: List[SequenceGroupMetadata]
......@@ -949,4 +951,4 @@ class ExecuteModelRequest:
running_queue_size=self.running_queue_size,
previous_hidden_states=self.previous_hidden_states,
num_steps=self.num_steps,
)
finished_requests_ids=self.finished_requests_ids)
......@@ -75,15 +75,19 @@ class TP1DraftModelRunner(ModelRunner):
List[SequenceGroupMetadata]] = None
def prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0) -> ModelInputForGPUWithSamplingMetadata:
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None
) -> ModelInputForGPUWithSamplingMetadata:
"""A temporary solution that caches the seq_group_metadata_list
for multi-step execution.
TODO: In-place update model_input and remove this function.
"""
self.cached_seq_group_metadata_list = seq_group_metadata_list
return super().prepare_model_input(seq_group_metadata_list)
return super().prepare_model_input(
seq_group_metadata_list,
finished_requests_ids=finished_requests_ids)
def update_model_input(
self, model_input: ModelInputForGPUWithSamplingMetadata,
......
......@@ -33,7 +33,9 @@ class CacheEngine:
self.device_config = device_config
self.head_size = model_config.get_head_size()
self.num_layers = model_config.get_num_layers(parallel_config)
# Models like Jamba, have mixed typed layers, E.g Mamba
self.num_attention_layers = model_config.get_num_attention_layers(
parallel_config)
self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
self.block_size = cache_config.block_size
......@@ -75,7 +77,7 @@ class CacheEngine:
num_blocks, self.block_size, self.num_kv_heads, self.head_size)
pin_memory = is_pin_memory_available() if device == "cpu" else False
kv_cache: List[torch.Tensor] = []
for _ in range(self.num_layers):
for _ in range(self.num_attention_layers):
# null block in CpuGpuBlockAllocator requires at least that
# block to be zeroed-out.
# We zero-out everything for simplicity.
......@@ -87,12 +89,12 @@ class CacheEngine:
return kv_cache
def swap_in(self, src_to_dst: torch.Tensor) -> None:
for i in range(self.num_layers):
for i in range(self.num_attention_layers):
self.attn_backend.swap_blocks(self.cpu_cache[i], self.gpu_cache[i],
src_to_dst)
def swap_out(self, src_to_dst: torch.Tensor) -> None:
for i in range(self.num_layers):
for i in range(self.num_attention_layers):
self.attn_backend.swap_blocks(self.gpu_cache[i], self.cpu_cache[i],
src_to_dst)
......@@ -107,11 +109,12 @@ class CacheEngine:
) -> int:
head_size = model_config.get_head_size()
num_heads = model_config.get_num_kv_heads(parallel_config)
num_layers = model_config.get_num_layers(parallel_config)
num_attention_layers = model_config.get_num_attention_layers(
parallel_config)
key_cache_block = cache_config.block_size * num_heads * head_size
value_cache_block = key_cache_block
total = num_layers * (key_cache_block + value_cache_block)
total = num_attention_layers * (key_cache_block + value_cache_block)
if cache_config.cache_dtype == "auto":
dtype = model_config.dtype
else:
......
......@@ -314,9 +314,10 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
)
def prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None
) -> CPUModelInput:
multi_modal_kwargs = None
# NOTE: We assume that all sequences in the group are all prompts or
......
......@@ -120,10 +120,11 @@ class EmbeddingModelRunner(
self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None
) -> ModelInputForGPUWithPoolingMetadata:
assert seq_group_metadata_list is not None
model_input = self._prepare_model_input_tensors(
seq_group_metadata_list)
seq_group_metadata_list, finished_requests_ids)
# Prepare PoolingMetadata.
assert model_input.seq_lens is not None
pooling_metadata = self._prepare_pooling(seq_group_metadata_list,
......
......@@ -84,6 +84,8 @@ class ModelInputForGPU(ModelRunnerInputBase):
lora_requests: Optional[Set[LoRARequest]] = None
attn_metadata: Optional["AttentionMetadata"] = None
multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None
request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None
finished_requests_ids: Optional[List[str]] = None
virtual_engine: int = 0
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
......@@ -94,6 +96,8 @@ class ModelInputForGPU(ModelRunnerInputBase):
"lora_mapping": self.lora_mapping,
"multi_modal_kwargs": self.multi_modal_kwargs,
"virtual_engine": self.virtual_engine,
"request_ids_to_seq_ids": self.request_ids_to_seq_ids,
"finished_requests_ids": self.finished_requests_ids,
}
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
return tensor_dict
......@@ -128,6 +132,8 @@ class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
"lora_mapping": self.lora_mapping,
"multi_modal_kwargs": self.multi_modal_kwargs,
"virtual_engine": self.virtual_engine,
"request_ids_to_seq_ids": self.request_ids_to_seq_ids,
"finished_requests_ids": self.finished_requests_ids,
}
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
_add_sampling_metadata_broadcastable_dict(tensor_dict,
......@@ -191,6 +197,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
]
self.graph_memory_pool: Optional[Tuple[
int, int]] = None # Set during graph capture.
self.has_seqlen_agnostic = model_config.contains_seqlen_agnostic_layers(
parallel_config)
# When using CUDA graph, the input block tables must be padded to
# max_seq_len_to_capture. However, creating the block table in
# Python can be expensive. To optimize this, we cache the block table
......@@ -317,6 +327,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
def _prepare_model_input_tensors(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
finished_requests_ids: Optional[List[str]] = None
) -> TModelInputForGPU:
"""Helper method to prepare the model input based on a given sequence
group. Prepares metadata needed for the base model forward pass but not
......@@ -347,6 +358,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
block_tables: List[List[int]] = []
multi_modal_kwargs_list: Dict[str,
List[torch.Tensor]] = defaultdict(list)
request_ids_to_seq_ids: Dict[str, List[int]] = defaultdict(list)
decode_only = True
num_prefills = 0
num_prefill_tokens = 0
......@@ -738,7 +750,11 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
k: torch.cat(v, dim=0).to(self.device)
for k, v in multi_modal_kwargs_list.items()
}
request_ids_to_seq_ids = {
seq_group_metadata.request_id:
list(seq_group_metadata.seq_data.keys())
for seq_group_metadata in seq_group_metadata_list
}
return self._model_input_cls(
input_tokens=input_tokens_tensor,
input_positions=input_positions_tensor,
......@@ -748,7 +764,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
lora_mapping=lora_mapping,
lora_requests=lora_requests,
multi_modal_kwargs=multi_modal_kwargs,
)
request_ids_to_seq_ids=request_ids_to_seq_ids,
finished_requests_ids=finished_requests_ids)
@torch.inference_mode()
def profile_run(self) -> None:
......@@ -821,7 +838,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
# Run the model with the dummy inputs.
num_layers = self.model_config.get_num_layers(self.parallel_config)
kv_caches = [None] * num_layers
model_input = self.prepare_model_input(seqs)
finished_requests_ids = [seq.request_id for seq in seqs]
model_input = self.prepare_model_input(
seqs, finished_requests_ids=finished_requests_ids)
intermediate_tensors = None
if not get_pp_group().is_first_rank:
intermediate_tensors = self.model.make_empty_intermediate_tensors(
......@@ -1033,21 +1052,37 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
graph_runner.flashinfer_decode_wrapper = \
decode_wrapper
graph_runner.capture(
capture_inputs = {
"input_ids":
input_tokens[:batch_size],
"positions":
input_positions[:batch_size],
"hidden_or_intermediate_states":
hidden_or_intermediate_states[
virtual_engine] # type: ignore
[:batch_size]
if hidden_or_intermediate_states[virtual_engine]
is not None else None,
"intermediate_inputs":
intermediate_inputs[:batch_size]
if intermediate_inputs is not None else None,
"kv_caches":
kv_caches[virtual_engine],
"attn_metadata":
attn_metadata,
memory_pool=self.graph_memory_pool,
stream=graph_capture_context.stream,
)
"memory_pool":
self.graph_memory_pool,
"stream":
graph_capture_context.stream
}
if self.has_seqlen_agnostic:
# Only used by Mamba-based models CUDA graph atm (Jamba)
capture_inputs.update({
"seqlen_agnostic_capture_inputs":
self.model.get_seqlen_agnostic_capture_inputs(
batch_size)
})
graph_runner.capture(**capture_inputs)
self.graph_memory_pool = graph_runner.graph.pool()
self.graph_runners[virtual_engine][batch_size] = (
graph_runner)
......@@ -1084,6 +1119,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None
) -> ModelInputForGPUWithSamplingMetadata:
"""Prepare the model input based on a given sequence group, including
metadata for the sampling step.
......@@ -1099,7 +1135,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
If cuda graph is required, this API automatically pads inputs.
"""
model_input = self._prepare_model_input_tensors(
seq_group_metadata_list)
seq_group_metadata_list, finished_requests_ids)
sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
model_input.seq_lens,
model_input.query_lens,
......@@ -1175,6 +1211,10 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
model_executable = self.model
multi_modal_kwargs = model_input.multi_modal_kwargs or {}
seqlen_agnostic_kwargs = {
"finished_requests_ids": model_input.finished_requests_ids,
"request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
} if self.has_seqlen_agnostic else {}
hidden_or_intermediate_states = model_executable(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
......@@ -1182,7 +1222,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
attn_metadata=model_input.attn_metadata,
intermediate_tensors=intermediate_tensors,
**multi_modal_kwargs,
)
**seqlen_agnostic_kwargs)
# Compute the logits in the last pipeline stage.
if not get_pp_group().is_last_rank:
......@@ -1305,6 +1345,7 @@ class CUDAGraphRunner:
"positions": positions,
"kv_caches": kv_caches,
"slot_mapping": attn_metadata.slot_mapping,
**kwargs,
}
else:
self.input_buffers = {
......@@ -1315,6 +1356,7 @@ class CUDAGraphRunner:
"seq_lens_tensor":
attn_metadata.decode_metadata.seq_lens_tensor,
"block_tables": attn_metadata.decode_metadata.block_tables,
**kwargs,
}
if intermediate_inputs is not None:
self.input_buffers.update(intermediate_inputs.tensors)
......@@ -1349,13 +1391,18 @@ class CUDAGraphRunner:
non_blocking=True)
self.input_buffers["block_tables"].copy_(
attn_metadata.decode_metadata.block_tables, non_blocking=True)
if "seqlen_agnostic_capture_inputs" in self.input_buffers:
self.model.copy_inputs_before_cuda_graphs(self.input_buffers,
**kwargs)
if intermediate_tensors is not None:
for key in intermediate_tensors.tensors:
self.input_buffers[key].copy_(intermediate_tensors[key],
non_blocking=True)
# Run the graph.
self.graph.replay()
if "seqlen_agnostic_capture_inputs" in self.input_buffers:
self.model.copy_outputs_after_cuda_graphs(self.input_buffers,
**kwargs)
# Return the output tensor.
if get_pp_group().is_last_rank:
return self.output_buffers["hidden_states"]
......
......@@ -139,6 +139,7 @@ class ModelRunnerBase(ABC, Generic[T]):
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None,
) -> T:
"""
Prepare the inputs to ModelRunnerBase.execute_model from an execution
......
......@@ -177,6 +177,7 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None
) -> ModelInputForNeuron:
# NOTE: We assume that all sequences in the group are all prompts or
# all decodes.
......
......@@ -234,7 +234,8 @@ class LocalOrDistributedWorkerBase(WorkerBase):
model_input: ModelRunnerInputBase = (
self.model_runner.prepare_model_input(
execute_model_req.seq_group_metadata_list,
execute_model_req.virtual_engine))
execute_model_req.virtual_engine,
execute_model_req.finished_requests_ids))
num_steps = execute_model_req.num_steps
if self.do_metadata_broadcast:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment