Unverified Commit d59c9864 authored by Maximilien de Bayser's avatar Maximilien de Bayser Committed by GitHub
Browse files

Remove runtime checks based on pooling params (#24051)


Signed-off-by: default avatarMax de Bayser <mbayser@br.ibm.com>
parent 04d0c607
...@@ -704,17 +704,12 @@ class InputBatch: ...@@ -704,17 +704,12 @@ class InputBatch:
logitsprocs=self.logitsprocs, logitsprocs=self.logitsprocs,
) )
@property def get_pooling_params(self) -> list[PoolingParams]:
def pooling_metadata(self) -> PoolingMetadata:
if len(self.pooling_params) == 0:
pooling_params = []
else:
# Note, for now this assumes that all request in the batch
# are either sampling or pooling requests
assert len(self.req_ids) == len(self.pooling_params) assert len(self.req_ids) == len(self.pooling_params)
pooling_params = [ return [self.pooling_params[req_id] for req_id in self.req_ids]
self.pooling_params[req_id] for req_id in self.req_ids
] def get_pooling_metadata(self) -> PoolingMetadata:
pooling_params = self.get_pooling_params()
return PoolingMetadata( return PoolingMetadata(
prompt_lens=torch.from_numpy( prompt_lens=torch.from_numpy(
......
...@@ -138,7 +138,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -138,7 +138,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[ self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
cache_config.cache_dtype] cache_config.cache_dtype]
self.is_pooling_model = model_config.pooler_config is not None self.is_pooling_model = (model_config.runner_type == 'pooling')
self.is_multimodal_raw_input_only_model = ( self.is_multimodal_raw_input_only_model = (
model_config.is_multimodal_raw_input_only_model) model_config.is_multimodal_raw_input_only_model)
...@@ -332,17 +332,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -332,17 +332,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def _init_model_kwargs(self, num_tokens: int): def _init_model_kwargs(self, num_tokens: int):
model_kwargs = dict[str, Any]() model_kwargs = dict[str, Any]()
num_reqs = self.input_batch.num_reqs
num_pooling_reqs = len(self.input_batch.pooling_params)
if num_pooling_reqs == 0: if not self.is_pooling_model:
return model_kwargs return model_kwargs
# This does nontrivial work. num_reqs = self.input_batch.num_reqs
pooling_params = self.input_batch.pooling_metadata.pooling_params pooling_params = self.input_batch.get_pooling_params()
assert num_pooling_reqs == num_reqs
token_type_id_requests = dict[int, Any]() token_type_id_requests = dict[int, Any]()
for i, param in enumerate(pooling_params): for i, param in enumerate(pooling_params):
...@@ -456,7 +451,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -456,7 +451,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
else: else:
generator = None generator = None
if pooling_params: if self.is_pooling_model:
assert pooling_params is not None
task = pooling_params.task task = pooling_params.task
assert task is not None, "You did not set `task` in the API" assert task is not None, "You did not set `task` in the API"
...@@ -1437,7 +1433,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -1437,7 +1433,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
" a batch must be pooling request" " a batch must be pooling request"
hidden_states = hidden_states[:num_scheduled_tokens] hidden_states = hidden_states[:num_scheduled_tokens]
pooling_metadata = self.input_batch.pooling_metadata pooling_metadata = self.input_batch.get_pooling_metadata()
pooling_metadata.build_pooling_cursor(num_scheduled_tokens_np.tolist(), pooling_metadata.build_pooling_cursor(num_scheduled_tokens_np.tolist(),
device=hidden_states.device) device=hidden_states.device)
seq_lens_cpu = self.seq_lens.cpu[:self.input_batch.num_reqs] seq_lens_cpu = self.seq_lens.cpu[:self.input_batch.num_reqs]
...@@ -1609,7 +1605,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -1609,7 +1605,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
all_gather_group=get_tp_group()) all_gather_group=get_tp_group())
logits = None logits = None
else: else:
if self.input_batch.pooling_params: if self.is_pooling_model:
return self._pool(hidden_states, num_scheduled_tokens, return self._pool(hidden_states, num_scheduled_tokens,
num_scheduled_tokens_np, kv_connector_output) num_scheduled_tokens_np, kv_connector_output)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment