Unverified Commit f9ee6ae1 authored by Jintao Zhang, committed by GitHub

[router]: Add Embedding routing logic (#10129)


Signed-off-by: Jintao Zhang <zhangjintao9020@gmail.com>
Co-authored-by: Waël Boukhobza <wawa_wael@live.fr>
parent dcee42c2
@@ -155,33 +155,35 @@ jobs:
         id: start_servers
         run: |
           echo "Starting disaggregation servers..."
-          bash scripts/ci/ci_start_disaggregation_servers.sh &
+          READY_FILE=".disagg_ready"
+          rm -f "$READY_FILE"
+          DISAGG_READY_FILE="$READY_FILE" bash scripts/ci/ci_start_disaggregation_servers.sh &
           SERVER_PID=$!
           echo "server_pid=$SERVER_PID" >> $GITHUB_OUTPUT
-          # Wait for all 8 servers to be healthy (script already does this)
-          wait_count=0
-          while [ $wait_count -lt 30 ]; do
-            if ps -p $SERVER_PID > /dev/null; then
-              # Check if the startup script printed success message
-              sleep 2
-              wait_count=$((wait_count + 1))
-            else
-              # Script exited - check if it was successful
-              wait $SERVER_PID
-              exit_code=$?
-              if [ $exit_code -eq 0 ]; then
-                echo "✓ All disaggregation servers are healthy"
-                break
-              else
-                echo "Error: Server startup failed with code $exit_code"
-                exit 1
-              fi
+          # Wait until script signals readiness (8/8 healthy) or timeout
+          TIMEOUT=300
+          ELAPSED=0
+          while [ $ELAPSED -lt $TIMEOUT ]; do
+            if [ -f "$READY_FILE" ]; then
+              echo "✓ All disaggregation servers are healthy (signal detected)"
+              break
             fi
+            if ! ps -p $SERVER_PID > /dev/null; then
+              echo "Error: server bootstrap script exited prematurely"
+              exit 1
+            fi
+            sleep 5
+            ELAPSED=$((ELAPSED + 5))
           done
+          if [ $ELAPSED -ge $TIMEOUT ]; then
+            echo "❌ Timeout waiting for disaggregation servers to be healthy"
+            exit 1
+          fi
           echo "✓ Servers started (PID: $SERVER_PID)"
       - name: Test all policies sequentially
         timeout-minutes: 30
         run: |
......
#!/bin/bash
set -euo pipefail
# Optional: set DISAGG_READY_FILE to a filepath; when all servers are healthy, the script will
# create this file as a readiness signal (useful for CI to proceed to next steps).
DISAGG_READY_FILE="${DISAGG_READY_FILE:-}"
MODEL_PATH="/raid/models/meta-llama/Llama-3.1-8B-Instruct"
@@ -81,6 +86,13 @@ while true; do
if [ $HEALTHY_COUNT -eq 8 ]; then
echo "✅ All 8 servers are healthy!"
# Emit readiness signal file if requested
if [ -n "$DISAGG_READY_FILE" ]; then
echo "Creating readiness flag: $DISAGG_READY_FILE"
# Ensure parent dir exists; ignore errors
mkdir -p "$(dirname "$DISAGG_READY_FILE")" 2>/dev/null || true
touch "$DISAGG_READY_FILE"
fi
break
else
sleep 10 # Wait 10 seconds before next check
......
@@ -715,6 +715,29 @@ def e2e_router_only_rr():
_terminate(proc)
@pytest.fixture(scope="session")
def e2e_embedding_model() -> str:
"""Embedding model to use for E2E tests.
Defaults to an E5 Mistral model, can be overridden via E2E_EMBEDDING_MODEL env var.
"""
import os
return os.getenv("E2E_EMBEDDING_MODEL", "intfloat/e5-mistral-7b-instruct")
@pytest.fixture
def e2e_primary_embedding_worker(e2e_embedding_model: str):
"""Launch a single embedding worker using the specified model."""
port = _find_available_port()
base_url = f"http://127.0.0.1:{port}"
proc = _popen_launch_worker(e2e_embedding_model, base_url)
try:
yield SimpleNamespace(proc=proc, url=base_url)
finally:
_terminate(proc)
@pytest.fixture(scope="session")
def e2e_primary_worker(e2e_model: str):
port = _find_available_port()
......
from types import SimpleNamespace
import pytest
import requests
@pytest.mark.e2e
def test_embeddings_basic(
e2e_router_only_rr, e2e_primary_embedding_worker, e2e_embedding_model
):
base = e2e_router_only_rr.url
worker_url = e2e_primary_embedding_worker.url
# Attach embedding worker to router-only instance
r = requests.post(f"{base}/add_worker", params={"url": worker_url}, timeout=180)
r.raise_for_status()
# Simple embedding request with two inputs
payload = {
"model": e2e_embedding_model,
"input": [
"the quick brown fox",
"jumps over the lazy dog",
],
}
r = requests.post(f"{base}/v1/embeddings", json=payload, timeout=120)
assert r.status_code == 200, f"unexpected status: {r.status_code} {r.text}"
data = r.json()
assert "data" in data and isinstance(data["data"], list)
assert len(data["data"]) == 2
# Validate shape of embedding objects
for item in data["data"]:
assert "embedding" in item and isinstance(item["embedding"], list)
# Ensure non-empty vectors
assert len(item["embedding"]) > 0
@@ -143,6 +143,18 @@ pub fn init_metrics() {
"Generate request duration"
);
// Embedding request specific metrics
describe_counter!("sgl_router_embeddings_total", "Total embedding requests");
describe_histogram!(
"sgl_router_embeddings_duration_seconds",
"Embedding request duration"
);
describe_counter!(
"sgl_router_embeddings_errors_total",
"Embedding request errors"
);
describe_gauge!("sgl_router_embeddings_queue_size", "Embedding queue size");
// Running requests gauge for cache-aware policy
describe_gauge!(
"sgl_router_running_requests",
@@ -440,6 +452,27 @@ impl RouterMetrics {
histogram!("sgl_router_generate_duration_seconds").record(duration.as_secs_f64());
}
// Embeddings metrics
pub fn record_embeddings_request() {
counter!("sgl_router_embeddings_total").increment(1);
}
pub fn record_embeddings_duration(duration: Duration) {
histogram!("sgl_router_embeddings_duration_seconds").record(duration.as_secs_f64());
}
pub fn record_embeddings_error(error_type: &str) {
counter!(
"sgl_router_embeddings_errors_total",
"error_type" => error_type.to_string()
)
.increment(1);
}
pub fn set_embeddings_queue_size(size: usize) {
gauge!("sgl_router_embeddings_queue_size").set(size as f64);
}
// Running requests for cache-aware policy
pub fn set_running_requests(worker: &str, count: usize) {
gauge!("sgl_router_running_requests",
......
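Note on usage: taken together, the helpers above wrap a single embeddings request: time it, then record either a success or an error labeled by HTTP status. The real call site appears in the Router implementation further down this diff; the code below is only a standalone sketch of that calling pattern, with stub functions standing in for the RouterMetrics methods so it compiles on its own.

use std::time::{Duration, Instant};

// Stand-ins for the RouterMetrics methods above (illustrative, not part of this commit).
fn record_embeddings_request() {}
fn record_embeddings_duration(_d: Duration) {}
fn record_embeddings_error(_error_type: &str) {}

fn record_outcome(status: u16, start: Instant) {
    if (200..300).contains(&status) {
        record_embeddings_request();
        record_embeddings_duration(start.elapsed());
    } else {
        // Errors are labeled by status code so they can be aggregated per class.
        record_embeddings_error(&format!("http_{}", status));
    }
}

fn main() {
    let start = Instant::now();
    // ... the request would be forwarded to a worker here ...
    record_outcome(200, start);
}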
@@ -3,6 +3,7 @@ use axum::{
response::IntoResponse, response::Response,
};
use rand::Rng;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Duration;
use std::time::Instant;
@@ -13,6 +14,7 @@ use tracing::{debug, error, field::Empty, info, info_span, warn, Span};
pub use crate::core::token_bucket::TokenBucket;
use crate::metrics::RouterMetrics;
use crate::server::AppState;
/// Generate OpenAI-compatible request ID based on endpoint
@@ -441,6 +443,11 @@ pub async fn concurrency_limit_middleware(
request: Request<axum::body::Body>,
next: Next,
) -> Response {
// Static counter for embeddings queue size
static EMBEDDINGS_QUEUE_SIZE: AtomicU64 = AtomicU64::new(0);
// Identify if this is an embeddings request based on path
let is_embeddings = request.uri().path().contains("/v1/embeddings");
let token_bucket = app_state.context.rate_limiter.clone();
// Try to acquire token immediately
@@ -468,10 +475,23 @@
// Try to send to queue
match queue_tx.try_send(queued) {
Ok(_) => {
// On successful enqueue, update embeddings queue gauge if applicable
if is_embeddings {
let new_val = EMBEDDINGS_QUEUE_SIZE.fetch_add(1, Ordering::Relaxed) + 1;
RouterMetrics::set_embeddings_queue_size(new_val as usize);
}
// Wait for token from queue processor
match permit_rx.await {
Ok(Ok(())) => {
debug!("Acquired token from queue");
// Dequeue for embeddings
if is_embeddings {
let new_val =
EMBEDDINGS_QUEUE_SIZE.fetch_sub(1, Ordering::Relaxed) - 1;
RouterMetrics::set_embeddings_queue_size(new_val as usize);
}
let response = next.run(request).await;
// Return the token to the bucket
@@ -481,10 +501,22 @@
}
Ok(Err(status)) => {
warn!("Queue returned error status: {}", status);
// Dequeue for embeddings on error
if is_embeddings {
let new_val =
EMBEDDINGS_QUEUE_SIZE.fetch_sub(1, Ordering::Relaxed) - 1;
RouterMetrics::set_embeddings_queue_size(new_val as usize);
}
status.into_response()
}
Err(_) => {
error!("Queue response channel closed");
// Dequeue for embeddings on channel error
if is_embeddings {
let new_val =
EMBEDDINGS_QUEUE_SIZE.fetch_sub(1, Ordering::Relaxed) - 1;
RouterMetrics::set_embeddings_queue_size(new_val as usize);
}
StatusCode::INTERNAL_SERVER_ERROR.into_response()
}
}
......
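Aside: the three fetch_sub call sites above keep the gauge consistent on every exit path (success, queue error, closed channel). A drop guard could centralize that bookkeeping. The sketch below is not part of this commit and the names are illustrative:

use std::sync::atomic::{AtomicU64, Ordering};

/// Decrements the queue counter (and re-publishes the gauge) when dropped,
/// so success, error, and early-return paths are all covered by one mechanism.
struct QueueSlot<'a> {
    counter: &'a AtomicU64,
}

impl<'a> QueueSlot<'a> {
    fn enter(counter: &'a AtomicU64) -> Self {
        let new_val = counter.fetch_add(1, Ordering::Relaxed) + 1;
        publish(new_val);
        QueueSlot { counter }
    }
}

impl Drop for QueueSlot<'_> {
    fn drop(&mut self) {
        let new_val = self.counter.fetch_sub(1, Ordering::Relaxed) - 1;
        publish(new_val);
    }
}

fn publish(size: u64) {
    // Stand-in for RouterMetrics::set_embeddings_queue_size(size as usize).
    println!("embeddings queue size: {size}");
}

fn main() {
    static QUEUE: AtomicU64 = AtomicU64::new(0);
    let _slot = QueueSlot::enter(&QUEUE); // gauge goes to 1
} // _slot dropped here; gauge returns to 0

With such a guard held across the await points, no branch can leak a queued slot, since the decrement runs in Drop.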
@@ -41,7 +41,10 @@ use std::collections::HashMap;
 // 6. **SGLANG SPEC - RERANK API**
 //    - Request/Response structures
 //
-// 7. **COMMON**
+// 7. **OPENAI SPEC - Embeddings API**
+//    - Request structures
+//
+// 8. **COMMON**
 //    - GenerationRequest trait
 //    - StringOrArray & LoRAPath types
 //    - Helper functions
@@ -2013,6 +2016,61 @@ impl RerankResponse {
}
}
// ==================================================================
// = OPENAI SPEC - Embeddings API =
// ==================================================================
/// Embeddings request compatible with OpenAI API
/// We intentionally keep fields flexible to pass through to workers.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct EmbeddingRequest {
/// ID of the model to use
pub model: String,
/// Input can be a string, array of strings, tokens, or batch inputs
pub input: serde_json::Value,
/// Optional encoding format (e.g., "float", "base64")
#[serde(skip_serializing_if = "Option::is_none")]
pub encoding_format: Option<String>,
/// Optional user identifier
#[serde(skip_serializing_if = "Option::is_none")]
pub user: Option<String>,
/// Optional number of dimensions for the embedding
#[serde(skip_serializing_if = "Option::is_none")]
pub dimensions: Option<u32>,
/// SGLang extension: request id for tracking
#[serde(skip_serializing_if = "Option::is_none")]
pub rid: Option<String>,
}
impl GenerationRequest for EmbeddingRequest {
fn is_stream(&self) -> bool {
// Embeddings are non-streaming
false
}
fn get_model(&self) -> Option<&str> {
Some(&self.model)
}
fn extract_text_for_routing(&self) -> String {
// Best effort: extract text content for routing decisions
match &self.input {
serde_json::Value::String(s) => s.clone(),
serde_json::Value::Array(arr) => arr
.iter()
.filter_map(|v| v.as_str())
.collect::<Vec<_>>()
.join(" "),
_ => String::new(),
}
}
}
// ==================================================================
// = COMMON =
// ==================================================================
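Because input is carried as a raw serde_json::Value, token-ID batches pass through to workers untouched, while plain-text inputs still contribute text for cache-aware routing. Below is a self-contained illustration of both cases, mirroring the extract_text_for_routing logic in the impl above (the free function here is a stand-in for the trait method, not part of this commit):

use serde_json::json;

// Same best-effort logic as EmbeddingRequest::extract_text_for_routing above.
fn extract_text_for_routing(input: &serde_json::Value) -> String {
    match input {
        serde_json::Value::String(s) => s.clone(),
        serde_json::Value::Array(arr) => arr
            .iter()
            .filter_map(|v| v.as_str())
            .collect::<Vec<_>>()
            .join(" "),
        _ => String::new(),
    }
}

fn main() {
    // Text inputs contribute routing text...
    assert_eq!(extract_text_for_routing(&json!(["hello", "world"])), "hello world");
    // ...while pre-tokenized input yields an empty string, so the policy
    // falls back to non-content-based worker selection.
    assert_eq!(extract_text_for_routing(&json!([[1, 2, 3]])), "");
}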
@@ -2715,4 +2773,102 @@ mod tests {
assert_eq!(deserialized.results.len(), 2);
assert_eq!(deserialized.model, response.model);
}
// ==================================================================
// = EMBEDDINGS REQUEST TESTS =
// ==================================================================
#[test]
fn test_embedding_request_serialization_string_input() {
let req = EmbeddingRequest {
model: "test-emb".to_string(),
input: serde_json::Value::String("hello".to_string()),
encoding_format: Some("float".to_string()),
user: Some("user-1".to_string()),
dimensions: Some(128),
rid: Some("rid-123".to_string()),
};
let serialized = serde_json::to_string(&req).unwrap();
let deserialized: EmbeddingRequest = serde_json::from_str(&serialized).unwrap();
assert_eq!(deserialized.model, req.model);
assert_eq!(deserialized.input, req.input);
assert_eq!(deserialized.encoding_format, req.encoding_format);
assert_eq!(deserialized.user, req.user);
assert_eq!(deserialized.dimensions, req.dimensions);
assert_eq!(deserialized.rid, req.rid);
}
#[test]
fn test_embedding_request_serialization_array_input() {
let req = EmbeddingRequest {
model: "test-emb".to_string(),
input: serde_json::json!(["a", "b", "c"]),
encoding_format: None,
user: None,
dimensions: None,
rid: None,
};
let serialized = serde_json::to_string(&req).unwrap();
let de: EmbeddingRequest = serde_json::from_str(&serialized).unwrap();
assert_eq!(de.model, req.model);
assert_eq!(de.input, req.input);
}
#[test]
fn test_embedding_generation_request_trait_string() {
let req = EmbeddingRequest {
model: "emb-model".to_string(),
input: serde_json::Value::String("hello".to_string()),
encoding_format: None,
user: None,
dimensions: None,
rid: None,
};
assert!(!req.is_stream());
assert_eq!(req.get_model(), Some("emb-model"));
assert_eq!(req.extract_text_for_routing(), "hello");
}
#[test]
fn test_embedding_generation_request_trait_array() {
let req = EmbeddingRequest {
model: "emb-model".to_string(),
input: serde_json::json!(["hello", "world"]),
encoding_format: None,
user: None,
dimensions: None,
rid: None,
};
assert_eq!(req.extract_text_for_routing(), "hello world");
}
#[test]
fn test_embedding_generation_request_trait_non_text() {
let req = EmbeddingRequest {
model: "emb-model".to_string(),
input: serde_json::json!({"tokens": [1, 2, 3]}),
encoding_format: None,
user: None,
dimensions: None,
rid: None,
};
assert_eq!(req.extract_text_for_routing(), "");
}
#[test]
fn test_embedding_generation_request_trait_mixed_array_ignores_nested() {
let req = EmbeddingRequest {
model: "emb-model".to_string(),
input: serde_json::json!(["a", ["b", "c"], 123, {"k": "v"}]),
encoding_format: None,
user: None,
dimensions: None,
rid: None,
};
// Only top-level string elements are extracted
assert_eq!(req.extract_text_for_routing(), "a");
}
}
@@ -309,7 +309,12 @@ impl RouterTrait for GrpcPDRouter {
         (StatusCode::NOT_IMPLEMENTED).into_response()
     }

-    async fn route_embeddings(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
+    async fn route_embeddings(
+        &self,
+        _headers: Option<&HeaderMap>,
+        _body: &crate::protocols::spec::EmbeddingRequest,
+        _model_id: Option<&str>,
+    ) -> Response {
         (StatusCode::NOT_IMPLEMENTED).into_response()
     }
......
@@ -242,7 +242,12 @@ impl RouterTrait for GrpcRouter {
         (StatusCode::NOT_IMPLEMENTED).into_response()
     }

-    async fn route_embeddings(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
+    async fn route_embeddings(
+        &self,
+        _headers: Option<&HeaderMap>,
+        _body: &crate::protocols::spec::EmbeddingRequest,
+        _model_id: Option<&str>,
+    ) -> Response {
         (StatusCode::NOT_IMPLEMENTED).into_response()
     }
......
@@ -395,7 +395,12 @@ impl super::super::RouterTrait for OpenAIRouter {
         }
     }

-    async fn route_embeddings(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
+    async fn route_embeddings(
+        &self,
+        _headers: Option<&HeaderMap>,
+        _body: &crate::protocols::spec::EmbeddingRequest,
+        _model_id: Option<&str>,
+    ) -> Response {
         (
             StatusCode::NOT_IMPLEMENTED,
             "Embeddings endpoint not implemented for OpenAI backend",
......
@@ -1938,8 +1938,17 @@ impl RouterTrait for PDRouter {
             .into_response()
     }

-    async fn route_embeddings(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
-        todo!()
+    async fn route_embeddings(
+        &self,
+        _headers: Option<&HeaderMap>,
+        _body: &crate::protocols::spec::EmbeddingRequest,
+        _model_id: Option<&str>,
+    ) -> Response {
+        (
+            StatusCode::NOT_IMPLEMENTED,
+            "Embeddings endpoint not implemented for PD router",
+        )
+            .into_response()
     }
async fn route_rerank(
......
@@ -6,8 +6,8 @@ use crate::core::{
 use crate::metrics::RouterMetrics;
 use crate::policies::{LoadBalancingPolicy, PolicyRegistry};
 use crate::protocols::spec::{
-    ChatCompletionRequest, CompletionRequest, GenerateRequest, GenerationRequest, RerankRequest,
-    RerankResponse, RerankResult, ResponsesRequest,
+    ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, GenerationRequest,
+    RerankRequest, RerankResponse, RerankResult, ResponsesRequest,
 };
 use crate::routers::header_utils;
 use crate::routers::{RouterTrait, WorkerManagement};
@@ -1430,8 +1430,28 @@ impl RouterTrait for Router {
         self.route_post_empty_request(headers, &endpoint).await
     }

-    async fn route_embeddings(&self, _headers: Option<&HeaderMap>, _body: Body) -> Response {
-        todo!()
+    async fn route_embeddings(
+        &self,
+        headers: Option<&HeaderMap>,
+        body: &EmbeddingRequest,
+        model_id: Option<&str>,
+    ) -> Response {
+        // Record embeddings-specific metrics in addition to general request metrics
+        let start = Instant::now();
+        let res = self
+            .route_typed_request(headers, body, "/v1/embeddings", model_id)
+            .await;
+
+        // Embedding specific metrics
+        if res.status().is_success() {
+            RouterMetrics::record_embeddings_request();
+            RouterMetrics::record_embeddings_duration(start.elapsed());
+        } else {
+            let error_type = format!("http_{}", res.status().as_u16());
+            RouterMetrics::record_embeddings_error(&error_type);
+        }
+
+        res
     }

async fn route_rerank(
......
@@ -10,7 +10,8 @@ use axum::{
 use std::fmt::Debug;

 use crate::protocols::spec::{
-    ChatCompletionRequest, CompletionRequest, GenerateRequest, RerankRequest, ResponsesRequest,
+    ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest,
+    ResponsesRequest,
 };

 pub mod factory;
@@ -123,7 +124,13 @@ pub trait RouterTrait: Send + Sync + Debug + WorkerManagement {
             .into_response()
     }

-    async fn route_embeddings(&self, headers: Option<&HeaderMap>, body: Body) -> Response;
+    /// Route embedding requests (OpenAI-compatible /v1/embeddings)
+    async fn route_embeddings(
+        &self,
+        headers: Option<&HeaderMap>,
+        body: &EmbeddingRequest,
+        model_id: Option<&str>,
+    ) -> Response;

     async fn route_rerank(
         &self,
......
@@ -7,7 +7,8 @@
 use crate::config::RouterConfig;
 use crate::core::{CircuitBreakerConfig, Worker, WorkerFactory, WorkerRegistry};
 use crate::protocols::spec::{
-    ChatCompletionRequest, CompletionRequest, GenerateRequest, RerankRequest, ResponsesRequest,
+    ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest,
+    ResponsesRequest,
 };
 use crate::protocols::worker_spec::{
     ServerInfo, WorkerApiResponse, WorkerConfigRequest, WorkerErrorResponse, WorkerInfo,
@@ -665,22 +666,6 @@ impl RouterTrait for RouterManager {
             .into_response()
     }

-    async fn get_response(&self, _headers: Option<&HeaderMap>, _response_id: &str) -> Response {
-        (
-            StatusCode::NOT_IMPLEMENTED,
-            "responses api not yet implemented in inference gateway mode",
-        )
-            .into_response()
-    }
-
-    async fn cancel_response(&self, _headers: Option<&HeaderMap>, _response_id: &str) -> Response {
-        (
-            StatusCode::NOT_IMPLEMENTED,
-            "responses api not yet implemented in inference gateway mode",
-        )
-            .into_response()
-    }
-
     async fn delete_response(&self, _headers: Option<&HeaderMap>, _response_id: &str) -> Response {
         (
             StatusCode::NOT_IMPLEMENTED,
@@ -701,17 +686,51 @@
             .into_response()
     }

-    /// Route embeddings request
-    async fn route_embeddings(&self, headers: Option<&HeaderMap>, body: Body) -> Response {
-        // Try to select a router based on headers
-        let router = self.select_router_for_request(headers, None);
-        if let Some(router) = router {
-            router.route_embeddings(headers, body).await
-        } else {
-            (
-                StatusCode::NOT_FOUND,
-                "No router available for embeddings request",
-            )
-                .into_response()
-        }
-    }
+    async fn get_response(&self, headers: Option<&HeaderMap>, response_id: &str) -> Response {
+        let router = self.select_router_for_request(headers, None);
+        if let Some(router) = router {
+            router.get_response(headers, response_id).await
+        } else {
+            (
+                StatusCode::NOT_FOUND,
+                format!("No router available to get response '{}'", response_id),
+            )
+                .into_response()
+        }
+    }
+
+    async fn cancel_response(&self, headers: Option<&HeaderMap>, response_id: &str) -> Response {
+        let router = self.select_router_for_request(headers, None);
+        if let Some(router) = router {
+            router.cancel_response(headers, response_id).await
+        } else {
+            (
+                StatusCode::NOT_FOUND,
+                format!("No router available to cancel response '{}'", response_id),
+            )
+                .into_response()
+        }
+    }
+
+    /// Route embeddings request
+    async fn route_embeddings(
+        &self,
+        headers: Option<&HeaderMap>,
+        body: &EmbeddingRequest,
+        _model_id: Option<&str>,
+    ) -> Response {
+        // Select router based on headers and model
+        let router = self.select_router_for_request(headers, Some(&body.model));
+        if let Some(router) = router {
+            router
+                .route_embeddings(headers, body, Some(&body.model))
+                .await
+        } else {
+            // Return 404 when the specified model is not found
+            (
+                StatusCode::NOT_FOUND,
+                format!("Model '{}' not found or no router available", body.model),
+            )
+                .into_response()
+        }
+    }
......
@@ -5,8 +5,8 @@ use crate::metrics::{self, PrometheusConfig};
 use crate::middleware::TokenBucket;
 use crate::policies::PolicyRegistry;
 use crate::protocols::spec::{
-    ChatCompletionRequest, CompletionRequest, GenerateRequest, RerankRequest, ResponsesRequest,
-    V1RerankReqInput,
+    ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest,
+    ResponsesRequest, V1RerankReqInput,
 };
 use crate::protocols::worker_spec::{WorkerApiResponse, WorkerConfigRequest, WorkerErrorResponse};
 use crate::reasoning_parser::ParserFactory;
@@ -208,6 +208,17 @@ async fn v1_responses(
.await
}
async fn v1_embeddings(
State(state): State<Arc<AppState>>,
headers: http::HeaderMap,
Json(body): Json<EmbeddingRequest>,
) -> Response {
state
.router
.route_embeddings(Some(&headers), &body, None)
.await
}
async fn v1_responses_get(
State(state): State<Arc<AppState>>,
Path(response_id): Path<String>,
@@ -465,6 +476,7 @@ pub fn build_app(
.route("/rerank", post(rerank))
.route("/v1/rerank", post(v1_rerank))
.route("/v1/responses", post(v1_responses))
.route("/v1/embeddings", post(v1_embeddings))
.route("/v1/responses/{response_id}", get(v1_responses_get))
.route(
"/v1/responses/{response_id}/cancel",
......
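With the /v1/embeddings route registered, any OpenAI-style client can call the router directly. A minimal client sketch follows, assuming a router on 127.0.0.1:30000 with an embedding worker already attached; reqwest, tokio, and serde_json are illustrative dependencies here, not part of this commit, and the model name matches the e2e test default:

use serde_json::json;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let client = reqwest::Client::new();
    // Same request shape exercised by the e2e test earlier in this diff.
    let resp = client
        .post("http://127.0.0.1:30000/v1/embeddings")
        .json(&json!({
            "model": "intfloat/e5-mistral-7b-instruct",
            "input": ["the quick brown fox", "jumps over the lazy dog"],
        }))
        .send()
        .await?;
    let body: serde_json::Value = resp.json().await?;
    // Expect one embedding object per input, each with a non-empty vector.
    println!("{} vectors returned", body["data"].as_array().map_or(0, |a| a.len()));
    Ok(())
}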
@@ -1090,10 +1090,14 @@ mod responses_endpoint_tests {
         let app = ctx.create_app().await;

         // First create a response to obtain an id
+        let resp_id = "test-get-resp-id-123";
         let payload = json!({
             "input": "Hello Responses API",
             "model": "mock-model",
-            "stream": false
+            "stream": false,
+            "store": true,
+            "background": true,
+            "request_id": resp_id
         });
         let req = Request::builder()
             .method("POST")
@@ -1103,11 +1107,6 @@
             .unwrap();
         let resp = app.clone().oneshot(req).await.unwrap();
         assert_eq!(resp.status(), StatusCode::OK);
-        let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
-            .await
-            .unwrap();
-        let body_json: serde_json::Value = serde_json::from_slice(&body).unwrap();
-        let resp_id = body_json["id"].as_str().unwrap().to_string();

         // Retrieve the response
         let req = Request::builder()
@@ -1140,10 +1139,14 @@
         let app = ctx.create_app().await;

         // First create a response to obtain an id
+        let resp_id = "test-cancel-resp-id-456";
         let payload = json!({
             "input": "Hello Responses API",
             "model": "mock-model",
-            "stream": false
+            "stream": false,
+            "store": true,
+            "background": true,
+            "request_id": resp_id
         });
         let req = Request::builder()
             .method("POST")
@@ -1153,11 +1156,6 @@
             .unwrap();
         let resp = app.clone().oneshot(req).await.unwrap();
         assert_eq!(resp.status(), StatusCode::OK);
-        let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
-            .await
-            .unwrap();
-        let body_json: serde_json::Value = serde_json::from_slice(&body).unwrap();
-        let resp_id = body_json["id"].as_str().unwrap().to_string();

         // Cancel the response
         let req = Request::builder()
......
@@ -20,7 +20,12 @@ import torch
 from transformers import AutoConfig, AutoTokenizer

 from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
-from sglang.test.test_utils import CustomTestCase, get_similarities, is_in_ci
+from sglang.test.test_utils import (
+    CustomTestCase,
+    get_similarities,
+    is_in_amd_ci,
+    is_in_ci,
+)

 MODELS = [
     ("Alibaba-NLP/gte-Qwen2-1.5B-instruct", 1, 1e-5),
@@ -74,11 +79,13 @@ class TestEmbeddingModels(CustomTestCase):
) as hf_runner:
hf_outputs = hf_runner.forward(truncated_prompts)
attention_backend = "triton" if is_in_amd_ci() else None
with SRTRunner(
model_path,
tp_size=tp_size,
torch_dtype=torch_dtype,
model_type="embedding",
attention_backend=attention_backend,
) as srt_runner:
srt_outputs = srt_runner.forward(truncated_prompts)
......