Unverified Commit d809906e authored by Kris Hung's avatar Kris Hung Committed by GitHub
Browse files

feat: Add embedding support to sgl backend (#3427)


Signed-off-by: default avatarkrishung5 <krish@nvidia.com>
parent 4c888bf4
...@@ -149,6 +149,47 @@ class CompletionPayload(BasePayload): ...@@ -149,6 +149,47 @@ class CompletionPayload(BasePayload):
return CompletionPayload.extract_text(response) return CompletionPayload.extract_text(response)
@dataclass
class EmbeddingPayload(BasePayload):
"""Payload for embeddings endpoint."""
endpoint: str = "/v1/embeddings"
@staticmethod
def extract_embeddings(response):
"""
Process embeddings API responses.
"""
response.raise_for_status()
result = response.json()
assert "object" in result, "Missing 'object' in response"
assert (
result["object"] == "list"
), f"Expected object='list', got {result['object']}"
assert "data" in result, "Missing 'data' in response"
assert len(result["data"]) > 0, "Empty data in response"
# Extract embedding vectors and validate structure
embeddings = []
for item in result["data"]:
assert "object" in item, "Missing 'object' in embedding item"
assert (
item["object"] == "embedding"
), f"Expected object='embedding', got {item['object']}"
assert "embedding" in item, "Missing 'embedding' vector in item"
assert isinstance(
item["embedding"], list
), "Embedding should be a list of floats"
assert len(item["embedding"]) > 0, "Embedding vector should not be empty"
embeddings.append(item["embedding"])
# Return a summary string for validation
return f"Generated {len(embeddings)} embeddings with dimension {len(embeddings[0])}"
def response_handler(self, response: Any) -> str:
return EmbeddingPayload.extract_embeddings(response)
@dataclass @dataclass
class MetricsPayload(BasePayload): class MetricsPayload(BasePayload):
endpoint: str = "/metrics" endpoint: str = "/metrics"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment