Commit b92312df authored by haosdent's avatar haosdent Committed by khluu
Browse files

[CI] Fix SPLADE pooler test broken by #38139 (#38495)


Signed-off-by: default avatarhaosdent <haosdent@gmail.com>
(cherry picked from commit a08b7733)
parent d816834c
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import types
import pytest import pytest
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -11,6 +9,8 @@ from vllm.model_executor.models.bert import ( ...@@ -11,6 +9,8 @@ from vllm.model_executor.models.bert import (
BertMLMHead, BertMLMHead,
SPLADESparsePooler, SPLADESparsePooler,
) )
from vllm.pooling_params import PoolingParams
from vllm.v1.pool.metadata import PoolingMetadata, PoolingStates
# --------------------------------------------------------------------- # ---------------------------------------------------------------------
# Functional test: SPLADE formula correctness (no HF download needed) # Functional test: SPLADE formula correctness (no HF download needed)
...@@ -38,8 +38,12 @@ def test_splade_pooler_matches_reference_formula(B, T, H, V): ...@@ -38,8 +38,12 @@ def test_splade_pooler_matches_reference_formula(B, T, H, V):
], ],
dtype=torch.long, dtype=torch.long,
) )
meta = types.SimpleNamespace( meta = PoolingMetadata(
prompt_lens=prompt_lens_tenser, prompt_token_ids=token_ids prompt_lens=prompt_lens_tenser,
prompt_token_ids=token_ids,
prompt_token_ids_cpu=token_ids,
pooling_params=[PoolingParams(task="embed")] * B,
pooling_states=[PoolingStates() for _ in range(B)],
) )
# MLM head (prefer BertMLMHead, fallback to Linear if unavailable) # MLM head (prefer BertMLMHead, fallback to Linear if unavailable)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment