Unverified Commit 9c1f78d5 authored by omrishiv's avatar omrishiv Committed by GitHub
Browse files

[Bugfix] update neuron for version > 0.5.0 (#7175)


Signed-off-by: default avataromrishiv <327609+omrishiv@users.noreply.github.com>
Co-authored-by: default avatarCyrus Leung <cyrus.tl.leung@gmail.com>
parent fc93e561
...@@ -316,7 +316,7 @@ class EngineArgs: ...@@ -316,7 +316,7 @@ class EngineArgs:
parser.add_argument('--block-size', parser.add_argument('--block-size',
type=int, type=int,
default=EngineArgs.block_size, default=EngineArgs.block_size,
choices=[8, 16, 32], choices=[8, 16, 32, 128, 256, 512, 1024, 2048],
help='Token block size for contiguous chunks of ' help='Token block size for contiguous chunks of '
'tokens.') 'tokens.')
......
...@@ -100,9 +100,8 @@ class NeuronExecutorAsync(NeuronExecutor, ExecutorAsyncBase): ...@@ -100,9 +100,8 @@ class NeuronExecutorAsync(NeuronExecutor, ExecutorAsyncBase):
self, self,
execute_model_req: ExecuteModelRequest, execute_model_req: ExecuteModelRequest,
) -> List[SamplerOutput]: ) -> List[SamplerOutput]:
output = await make_async( output = await make_async(self.driver_worker.execute_model
self.driver_worker.execute_model )(execute_model_req=execute_model_req, )
)(seq_group_metadata_list=execute_model_req.seq_group_metadata_list, )
return output return output
async def check_health_async(self) -> None: async def check_health_async(self) -> None:
......
...@@ -197,6 +197,7 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]): ...@@ -197,6 +197,7 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
virtual_engine: int = 0, virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None finished_requests_ids: Optional[List[str]] = None
) -> ModelInputForNeuron: ) -> ModelInputForNeuron:
multi_modal_kwargs = None
# NOTE: We assume that all sequences in the group are all prompts or # NOTE: We assume that all sequences in the group are all prompts or
# all decodes. # all decodes.
is_prompt = seq_group_metadata_list[0].is_prompt is_prompt = seq_group_metadata_list[0].is_prompt
......
...@@ -89,6 +89,9 @@ class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): ...@@ -89,6 +89,9 @@ class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
return WorkerInput(num_seq_groups=len( return WorkerInput(num_seq_groups=len(
execute_model_req.seq_group_metadata_list), ) execute_model_req.seq_group_metadata_list), )
def execute_worker(self, worker_input: WorkerInput) -> None:
pass
def get_cache_block_size_bytes(self) -> int: def get_cache_block_size_bytes(self) -> int:
"""Determine the size in bytes of a cache block. """Determine the size in bytes of a cache block.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment