Unverified commit d59c9864, authored by Maximilien de Bayser and committed by GitHub
Browse files

Remove runtime checks based on pooling params (#24051)


Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
parent 04d0c607
......@@ -704,17 +704,12 @@ class InputBatch:
logitsprocs=self.logitsprocs,
)
@property
def pooling_metadata(self) -> PoolingMetadata:
if len(self.pooling_params) == 0:
pooling_params = []
else:
# Note, for now this assumes that all request in the batch
# are either sampling or pooling requests
assert len(self.req_ids) == len(self.pooling_params)
pooling_params = [
self.pooling_params[req_id] for req_id in self.req_ids
]
def get_pooling_params(self) -> list[PoolingParams]:
    """Return the batch's pooling params, ordered like ``self.req_ids``.

    Asserts that ``self.pooling_params`` holds exactly as many entries as
    there are request ids. Each request id is then looked up in that
    mapping, so an id without an entry raises ``KeyError``.
    """
    request_ids = self.req_ids
    params_by_request = self.pooling_params
    assert len(request_ids) == len(params_by_request)

    # Follow the order of the request ids, not the mapping's own order.
    ordered_params = []
    for request_id in request_ids:
        ordered_params.append(params_by_request[request_id])
    return ordered_params
def get_pooling_metadata(self) -> PoolingMetadata:
pooling_params = self.get_pooling_params()
return PoolingMetadata(
prompt_lens=torch.from_numpy(
......
......@@ -138,7 +138,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
cache_config.cache_dtype]
self.is_pooling_model = model_config.pooler_config is not None
self.is_pooling_model = (model_config.runner_type == 'pooling')
self.is_multimodal_raw_input_only_model = (
model_config.is_multimodal_raw_input_only_model)
......@@ -332,17 +332,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def _init_model_kwargs(self, num_tokens: int):
model_kwargs = dict[str, Any]()
num_reqs = self.input_batch.num_reqs
num_pooling_reqs = len(self.input_batch.pooling_params)
if num_pooling_reqs == 0:
if not self.is_pooling_model:
return model_kwargs
# This does nontrivial work.
pooling_params = self.input_batch.pooling_metadata.pooling_params
assert num_pooling_reqs == num_reqs
num_reqs = self.input_batch.num_reqs
pooling_params = self.input_batch.get_pooling_params()
token_type_id_requests = dict[int, Any]()
for i, param in enumerate(pooling_params):
......@@ -456,7 +451,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
else:
generator = None
if pooling_params:
if self.is_pooling_model:
assert pooling_params is not None
task = pooling_params.task
assert task is not None, "You did not set `task` in the API"
......@@ -1437,7 +1433,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
" a batch must be pooling request"
hidden_states = hidden_states[:num_scheduled_tokens]
pooling_metadata = self.input_batch.pooling_metadata
pooling_metadata = self.input_batch.get_pooling_metadata()
pooling_metadata.build_pooling_cursor(num_scheduled_tokens_np.tolist(),
device=hidden_states.device)
seq_lens_cpu = self.seq_lens.cpu[:self.input_batch.num_reqs]
......@@ -1609,7 +1605,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
all_gather_group=get_tp_group())
logits = None
else:
if self.input_batch.pooling_params:
if self.is_pooling_model:
return self._pool(hidden_states, num_scheduled_tokens,
num_scheduled_tokens_np, kv_connector_output)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment