# state_management.py
"""
State management tests for Response API.

Tests both previous_response_id and conversation-based state management.
These tests should work across all backends (OpenAI, XAI, gRPC).
"""

import unittest

from base import ResponseAPIBaseTest


class StateManagementTests(ResponseAPIBaseTest):
    """Tests for state management using previous_response_id and conversation.

    The request helpers called here (``create_response``,
    ``create_conversation``, ``list_conversation_items``) are not defined in
    this class; presumably they come from ``ResponseAPIBaseTest`` -- confirm
    against ``base``. Each is used as returning an object exposing
    ``status_code`` and ``json()``.
    """

    def test_previous_response_id_chaining(self):
        """Test chaining responses using previous_response_id."""
        # First response seeds the facts the later turns must recall.
        resp1 = self.create_response(
            "My name is Alice and my friend is Bob. Remember it."
        )
        self.assertEqual(resp1.status_code, 200)
        response1_id = resp1.json()["id"]

        # Second response referencing first
        resp2 = self.create_response(
            "What is my name", previous_response_id=response1_id
        )
        self.assertEqual(resp2.status_code, 200)
        response2_data = resp2.json()

        # The model should remember the name from previous response
        output_text = self._extract_output_text(response2_data)
        self.assertIn("Alice", output_text)

        # Third response referencing second
        resp3 = self.create_response(
            "What is my friend name?",
            previous_response_id=response2_data["id"],
        )
        # Check the status before touching the body so an HTTP failure is
        # reported as such rather than as a JSON/key error.
        self.assertEqual(resp3.status_code, 200)
        output_text = self._extract_output_text(resp3.json())
        self.assertIn("Bob", output_text)

    @unittest.skip("TODO: Add the invalid previous_response_id check")
    def test_previous_response_id_invalid(self):
        """Test using invalid previous_response_id."""
        resp = self.create_response(
            "Test", previous_response_id="resp_invalid123", max_output_tokens=50
        )
        self.assertIn(resp.status_code, [400, 404])

    def test_conversation_with_multiple_turns(self):
        """Test state management using conversation ID."""
        # Create conversation
        conv_resp = self.create_conversation(metadata={"topic": "math"})
        self.assertEqual(conv_resp.status_code, 200)

        conversation_id = conv_resp.json()["id"]

        # First response in conversation
        resp1 = self.create_response("I have 5 apples.", conversation=conversation_id)
        self.assertEqual(resp1.status_code, 200)

        # Second response in same conversation
        resp2 = self.create_response(
            "How many apples do I have?",
            conversation=conversation_id,
        )
        self.assertEqual(resp2.status_code, 200)
        output_text = self._extract_output_text(resp2.json())

        # Should remember "5 apples"
        self.assertTrue(
            "5" in output_text or "five" in output_text.lower(),
            msg=f"expected '5' or 'five' in model output, got: {output_text!r}",
        )

        # Third response in same conversation
        resp3 = self.create_response(
            "If I get 3 more, how many total?",
            conversation=conversation_id,
        )
        self.assertEqual(resp3.status_code, 200)
        output_text = self._extract_output_text(resp3.json())

        # Should calculate 5 + 3 = 8
        self.assertTrue(
            "8" in output_text or "eight" in output_text.lower(),
            msg=f"expected '8' or 'eight' in model output, got: {output_text!r}",
        )

        list_resp = self.list_conversation_items(conversation_id)
        self.assertEqual(list_resp.status_code, 200)
        items = list_resp.json()["data"]
        # Should have at least 6 items (3 inputs + 3 outputs)
        self.assertGreaterEqual(len(items), 6)

    def test_mutually_exclusive_parameters(self):
        """Test that previous_response_id and conversation are mutually exclusive."""
        # Create conversation and response. Verify the setup calls succeeded so
        # a setup failure is not misreported as a KeyError on "id".
        conv_resp = self.create_conversation()
        self.assertEqual(conv_resp.status_code, 200)
        conversation_id = conv_resp.json()["id"]

        resp1 = self.create_response("Test")
        self.assertEqual(resp1.status_code, 200)
        response1_id = resp1.json()["id"]

        # Try to use both parameters
        resp = self.create_response(
            "This should fail",
            previous_response_id=response1_id,
            conversation=conversation_id,
        )

        # Should return 400 Bad Request
        self.assertEqual(resp.status_code, 400)
        error_data = resp.json()
        self.assertIn("error", error_data)
        self.assertIn("mutually exclusive", error_data["error"]["message"].lower())

    # Helper methods

    def _extract_output_text(self, response_data: dict) -> str:
        """Extract text content from response output.

        Collects the ``text`` of every content part whose ``type`` is
        ``"output_text"`` across all items in ``response_data["output"]``.

        Args:
            response_data: Parsed JSON body of a response.

        Returns:
            The collected texts joined with single spaces, or ``""`` when
            there is no output or no ``output_text`` part.
        """
        text_parts = []
        # ``or []`` (rather than a .get default) also covers keys that are
        # present but null, e.g. output items that carry no text content.
        for item in response_data.get("output") or []:
            for part in item.get("content") or []:
                if part.get("type") == "output_text":
                    # A null text would make str.join raise; treat it as empty.
                    text_parts.append(part.get("text") or "")

        return " ".join(text_parts)