// tests/common/mock_mcp_server.rs - Mock MCP server for testing
use rmcp::{
    handler::server::{router::tool::ToolRouter, wrapper::Parameters},
    model::*,
    service::RequestContext,
    tool, tool_handler, tool_router,
    transport::streamable_http_server::{
        session::local::LocalSessionManager, StreamableHttpService,
    },
    ErrorData as McpError, RoleServer, ServerHandler,
};
use tokio::net::TcpListener;

/// Mock MCP server that returns hardcoded responses for testing
pub struct MockMCPServer {
    pub port: u16,
    pub server_handle: Option<tokio::task::JoinHandle<()>>,
}

20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
/// Simple test server with mock search tools
///
/// Exposes the `brave_web_search` and `brave_local_search` tools, both of
/// which return mock text rather than real search results.
#[derive(Clone)]
pub struct MockSearchServer {
    // Router produced by `Self::tool_router()` when the server is constructed.
    tool_router: ToolRouter<MockSearchServer>,
}

#[tool_router]
impl MockSearchServer {
    /// Builds a mock server with its tool router populated.
    pub fn new() -> Self {
        let tool_router = Self::tool_router();
        Self { tool_router }
    }

    // Echoes the `query` argument back inside a mock result line. A missing or
    // non-string `query` is treated as "test".
    #[tool(description = "Mock web search tool")]
    fn brave_web_search(
        &self,
        Parameters(arguments): Parameters<serde_json::Map<String, serde_json::Value>>,
    ) -> Result<CallToolResult, McpError> {
        let query = match arguments.get("query") {
            Some(serde_json::Value::String(text)) => text.as_str(),
            _ => "test",
        };
        let message = format!("Mock search results for: {}", query);
        Ok(CallToolResult::success(vec![Content::text(message)]))
    }

    // Ignores its arguments and always returns the same mock result line.
    #[tool(description = "Mock local search tool")]
    fn brave_local_search(
        &self,
        Parameters(_ignored): Parameters<serde_json::Map<String, serde_json::Value>>,
    ) -> Result<CallToolResult, McpError> {
        let canned = Content::text("Mock local search results");
        Ok(CallToolResult::success(vec![canned]))
    }
}

#[tool_handler]
impl ServerHandler for MockSearchServer {
    fn get_info(&self) -> ServerInfo {
        ServerInfo {
            protocol_version: ProtocolVersion::V_2024_11_05,
            capabilities: ServerCapabilities::builder().enable_tools().build(),
66
            server_info: Implementation::from_build_env(),
67
68
69
70
71
72
73
74
75
76
77
78
79
            instructions: Some("Mock server for testing".to_string()),
        }
    }

    async fn initialize(
        &self,
        _request: InitializeRequestParam,
        _context: RequestContext<RoleServer>,
    ) -> Result<InitializeResult, McpError> {
        Ok(self.get_info())
    }
}

80
81
82
83
84
85
86
impl MockMCPServer {
    /// Start a mock MCP server on an available port
    pub async fn start() -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
        // Find an available port
        let listener = TcpListener::bind("127.0.0.1:0").await?;
        let port = listener.local_addr()?.port();

87
88
89
90
91
92
93
94
        // Create the MCP service using rmcp's StreamableHttpService
        let service = StreamableHttpService::new(
            || Ok(MockSearchServer::new()),
            LocalSessionManager::default().into(),
            Default::default(),
        );

        let app = axum::Router::new().nest_service("/mcp", service);
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135

        let server_handle = tokio::spawn(async move {
            axum::serve(listener, app)
                .await
                .expect("Mock MCP server failed to start");
        });

        // Give the server a moment to start
        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;

        Ok(MockMCPServer {
            port,
            server_handle: Some(server_handle),
        })
    }

    /// Get the full URL for this mock server
    pub fn url(&self) -> String {
        format!("http://127.0.0.1:{}/mcp", self.port)
    }

    /// Stop the mock server
    pub async fn stop(&mut self) {
        if let Some(handle) = self.server_handle.take() {
            handle.abort();
            // Wait a moment for cleanup
            tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
        }
    }
}

impl Drop for MockMCPServer {
    /// Aborts the serving task if it has not already been taken for shutdown.
    fn drop(&mut self) {
        let Some(task) = self.server_handle.take() else {
            return;
        };
        task.abort();
    }
}

#[cfg(test)]
mod tests {
    use super::MockMCPServer;

    /// The server starts on a non-zero port and reports that port in its URL.
    #[tokio::test]
    async fn test_mock_server_startup() {
        let mut server = MockMCPServer::start().await.unwrap();
        assert!(server.port > 0);
        assert!(server.url().contains(&server.port.to_string()));
        server.stop().await;
    }

    /// An rmcp client can connect to the mock server and sees both mock tools.
    #[tokio::test]
    async fn test_mock_server_with_rmcp_client() {
        use rmcp::{transport::StreamableHttpClientTransport, ServiceExt};

        let mut server = MockMCPServer::start().await.unwrap();

        let transport = StreamableHttpClientTransport::from_uri(server.url().as_str());
        // `expect` fails the test with the underlying error attached, instead
        // of a bare "is_ok" assertion that hides the cause.
        let client = ()
            .serve(transport)
            .await
            .expect("Should be able to connect to mock server");

        let tools = client
            .peer()
            .list_all_tools()
            .await
            .expect("Should be able to list tools");

        assert_eq!(tools.len(), 2, "Should have 2 tools");
        assert!(tools.iter().any(|t| t.name == "brave_web_search"));
        assert!(tools.iter().any(|t| t.name == "brave_local_search"));

        // Shutdown by dropping the client
        drop(client);

        server.stop().await;
    }
}