# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023-2024 SGLang Team
# Copyright 2025 ModelBest Inc. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
from typing import Any, Optional
from uuid import uuid4

from verl.utils.reward_score import gsm8k

from .base import BaseInteraction

logger = logging.getLogger(__name__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))


class Gsm8kInteraction(BaseInteraction):
    """A demo interaction for calculating the reward of gsm8k.

    - `start_interaction`: start an interaction instance for a trajectory.
    - `generate_response`: generate the response of the user.
    - `calculate_score`: calculate the score of the interaction.
    - `finalize_interaction`: finalize the interaction instance.
    """

    def __init__(self, config: dict):
        super().__init__(config)
        # Per-instance state keyed by instance id. Each entry holds the latest
        # assistant response, the ground-truth answer, and a "reward" slot
        # (initialised to 0.0; not updated anywhere in this class).
        self._instance_dict: dict[str, dict[str, Any]] = {}

    async def start_interaction(
        self, instance_id: Optional[str] = None, ground_truth: Optional[str] = None, **kwargs
    ) -> str:
        """Register a new interaction instance.

        Args:
            instance_id: Id to register the instance under. When None, a random
                UUID4 string is generated.
            ground_truth: Reference answer that responses are scored against.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            The id under which the instance was registered.
        """
        if instance_id is None:
            instance_id = str(uuid4())
        # Registering an existing id overwrites its previous state.
        self._instance_dict[instance_id] = {
            "response": "",
            "ground_truth": ground_truth,
            "reward": 0.0,
        }
        return instance_id

    async def generate_response(
        self, instance_id: str, messages: list[dict[str, Any]], **kwargs
    ) -> tuple[bool, str, float, dict]:
        """Score the latest assistant message and produce the user's reply.

        Args:
            instance_id: Id returned by `start_interaction`.
            messages: Conversation so far; each message is a dict read via its
                "role" and "content" keys.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            A tuple ``(should_terminate_sequence, response, reward, extra)``.
            The sequence terminates only when ``reward == 1.0``. ``extra`` is
            always an empty dict.

        Raises:
            KeyError: If `instance_id` has not been registered (or was finalized).
        """
        # Take the most recent assistant turn; stays "" if there is none.
        content = ""
        for item in reversed(messages):
            if item.get("role") == "assistant":
                # A message without a "content" key (or with content=None) would
                # otherwise store None and hand a non-string to the scorer.
                content = item.get("content") or ""
                break

        self._instance_dict[instance_id]["response"] = content

        reward = await self.calculate_score(instance_id)
        if reward == 1.0:
            response = "Your response is correct!"
            should_terminate_sequence = True
        else:
            response = "Your response is incorrect! You need to reflect on your answer and try again."
            should_terminate_sequence = False

        return should_terminate_sequence, response, reward, {}

    async def calculate_score(self, instance_id: str, **kwargs) -> float:
        """Score the stored response against the stored ground truth.

        Delegates to `gsm8k.compute_score` with ``method="strict"``,
        ``format_score=0.0`` and ``score=1.0``. Going by the keyword names, a
        correct answer yields 1.0; see `gsm8k.compute_score` for the exact
        extraction rules of strict mode.

        Raises:
            KeyError: If `instance_id` has not been registered (or was finalized).
        """
        return gsm8k.compute_score(
            self._instance_dict[instance_id]["response"],
            self._instance_dict[instance_id]["ground_truth"],
            method="strict",
            format_score=0.0,
            score=1.0,
        )

    async def finalize_interaction(self, instance_id: str, **kwargs) -> None:
        """Drop the state of a finished interaction instance.

        Raises:
            KeyError: If `instance_id` is not currently registered.
        """
        del self._instance_dict[instance_id]