Unverified commit bc5dd4f6, authored by Pooya Davoodi and committed by GitHub
Browse files

[Bugfix] Fix broken GritLM model and tests (missing pooling_metadata) (#16631)


Signed-off-by: Pooya Davoodi <pooya.davoodi@parasail.io>
parent dbb036cf
...@@ -57,24 +57,25 @@ def test_find_array(monkeypatch: pytest.MonkeyPatch): ...@@ -57,24 +57,25 @@ def test_find_array(monkeypatch: pytest.MonkeyPatch):
def server_embedding(): def server_embedding():
# GritLM embedding implementation is only supported by XFormers backend. # GritLM embedding implementation is only supported by XFormers backend.
args = ["--task", "embed", "--max_model_len", str(MAX_MODEL_LEN)] args = ["--task", "embed", "--max_model_len", str(MAX_MODEL_LEN)]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: with pytest.MonkeyPatch.context() as m:
yield remote_server m.setenv(STR_BACKEND_ENV_VAR, "XFORMERS")
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def server_generate(): def server_generate():
args = ["--task", "generate", "--max_model_len", str(MAX_MODEL_LEN)] args = ["--task", "generate", "--max_model_len", str(MAX_MODEL_LEN)]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: with pytest.MonkeyPatch.context() as m:
yield remote_server m.setenv(STR_BACKEND_ENV_VAR, "XFORMERS")
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
@pytest_asyncio.fixture @pytest_asyncio.fixture
async def client_embedding(monkeypatch: pytest.MonkeyPatch, async def client_embedding(server_embedding: RemoteOpenAIServer):
server_embedding: RemoteOpenAIServer): async with server_embedding.get_async_client() as async_client:
with monkeypatch.context() as m: yield async_client
m.setenv("VLLM_ATTENTION_BACKEND", "XFORMERS")
async with server_embedding.get_async_client() as async_client:
yield async_client
@pytest_asyncio.fixture @pytest_asyncio.fixture
......
...@@ -170,7 +170,8 @@ class GritLMPooler(nn.Module): ...@@ -170,7 +170,8 @@ class GritLMPooler(nn.Module):
mean_embeddings = sum_embeddings / num_non_instruction_tokens.unsqueeze( mean_embeddings = sum_embeddings / num_non_instruction_tokens.unsqueeze(
1) 1)
pooled_data = self.head(mean_embeddings) pooled_data = self.head(mean_embeddings,
pooling_metadata=pooling_metadata)
pooled_outputs = [ pooled_outputs = [
PoolingSequenceGroupOutput(data) for data in pooled_data PoolingSequenceGroupOutput(data) for data in pooled_data
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment