Unverified Commit bc5dd4f6 authored by Pooya Davoodi's avatar Pooya Davoodi Committed by GitHub
Browse files

[Bugfix] Fix broken GritLM model and tests (missing pooling_metadata) (#16631)


Signed-off-by: default avatarPooya Davoodi <pooya.davoodi@parasail.io>
parent dbb036cf
......@@ -57,24 +57,25 @@ def test_find_array(monkeypatch: pytest.MonkeyPatch):
def server_embedding():
# GritLM embedding implementation is only supported by XFormers backend.
args = ["--task", "embed", "--max_model_len", str(MAX_MODEL_LEN)]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
with pytest.MonkeyPatch.context() as m:
m.setenv(STR_BACKEND_ENV_VAR, "XFORMERS")
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
@pytest.fixture(scope="module")
def server_generate():
args = ["--task", "generate", "--max_model_len", str(MAX_MODEL_LEN)]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
with pytest.MonkeyPatch.context() as m:
m.setenv(STR_BACKEND_ENV_VAR, "XFORMERS")
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
@pytest_asyncio.fixture
async def client_embedding(monkeypatch: pytest.MonkeyPatch,
server_embedding: RemoteOpenAIServer):
with monkeypatch.context() as m:
m.setenv("VLLM_ATTENTION_BACKEND", "XFORMERS")
async with server_embedding.get_async_client() as async_client:
yield async_client
async def client_embedding(server_embedding: RemoteOpenAIServer):
async with server_embedding.get_async_client() as async_client:
yield async_client
@pytest_asyncio.fixture
......
......@@ -170,7 +170,8 @@ class GritLMPooler(nn.Module):
mean_embeddings = sum_embeddings / num_non_instruction_tokens.unsqueeze(
1)
pooled_data = self.head(mean_embeddings)
pooled_data = self.head(mean_embeddings,
pooling_metadata=pooling_metadata)
pooled_outputs = [
PoolingSequenceGroupOutput(data) for data in pooled_data
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment