Commit 20247eb8 authored by chenych

Update 0506

parent 6065b946
@@ -21,7 +21,7 @@ if TYPE_CHECKING:
     from transformers.models.llama.configuration_llama import LlamaConfig
 
-VALID_MODLE_TYPE = {"llama", "qwen2", "qwen2_vl", "qwen2_5_vl"}
+VALID_MODLE_TYPE = {"llama", "qwen2", "qwen2_vl", "qwen2_5_vl", "qwen3"}
 
 def get_device_flops(unit: str = "T") -> float:
......
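For context, a hedged sketch of how a model-type whitelist like VALID_MODLE_TYPE is typically consulted; the helper below is illustrative only and is not this file's actual implementation.

# Hedged sketch: how VALID_MODLE_TYPE is plausibly checked before FLOPs/MFU estimation.
# `check_model_type` is a hypothetical helper, not code from this commit.
def check_model_type(model_type: str) -> bool:
    """Return True if the HuggingFace model_type is in the supported whitelist."""
    if model_type not in VALID_MODLE_TYPE:
        print(f"Unsupported model type {model_type!r}; FLOPs-based metrics may be skipped.")
        return False
    return True

check_model_type("qwen3")  # newly whitelisted by this commit -> True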
@@ -18,9 +18,11 @@ from typing import Optional
 from transformers import AutoProcessor, AutoTokenizer, PreTrainedTokenizer, ProcessorMixin
 
-def get_tokenizer(model_path: str, **kwargs) -> PreTrainedTokenizer:
+def get_tokenizer(model_path: str, override_chat_template: Optional[str] = None, **kwargs) -> PreTrainedTokenizer:
     """Create a huggingface pretrained tokenizer."""
     tokenizer = AutoTokenizer.from_pretrained(model_path, **kwargs)
+    if override_chat_template is not None:
+        tokenizer.chat_template = override_chat_template
     if tokenizer.bos_token == "<bos>" and tokenizer.eos_token == "<eos>":
         # the EOS token in gemma2 & gemma3 is ambiguous, which may worsen RL performance.
@@ -35,12 +37,11 @@ def get_tokenizer(model_path: str, **kwargs) -> PreTrainedTokenizer:
     return tokenizer
 
-def get_processor(model_path: str, **kwargs) -> Optional[ProcessorMixin]:
+def get_processor(model_path: str, override_chat_template: Optional[str] = None, **kwargs) -> Optional[ProcessorMixin]:
     """Create a huggingface pretrained processor."""
     try:
         processor = AutoProcessor.from_pretrained(model_path, **kwargs)
     except Exception:
         processor = None
 
+    # guard against a failed processor load above (processor may be None)
+    if processor is not None and override_chat_template is not None:
+        processor.chat_template = override_chat_template
 
     # Avoid load tokenizer, see:
     # https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/auto/processing_auto.py#L344
......
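For reference, a minimal usage sketch of the new override_chat_template argument; the import path, model names, and Jinja template below are assumptions for illustration, not values from this commit.

# Usage sketch for the new override_chat_template argument.
# Import path, model names, and template are illustrative placeholders.
from verl.utils.tokenizer import get_processor, get_tokenizer

CUSTOM_TEMPLATE = "{% for m in messages %}{{ m['content'] }}{% endfor %}"

tokenizer = get_tokenizer("Qwen/Qwen3-8B")  # keeps the model's built-in chat template
tokenizer = get_tokenizer("Qwen/Qwen3-8B", override_chat_template=CUSTOM_TEMPLATE)
processor = get_processor("Qwen/Qwen2.5-VL-7B-Instruct", override_chat_template=CUSTOM_TEMPLATE)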
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from .config import RewardConfig
-from .function import FunctionRewardManager
+from .function import BatchFunctionRewardManager, FunctionRewardManager, SequentialFunctionRewardManager
 
-__all__ = ["FunctionRewardManager", "RewardConfig"]
+__all__ = ["BatchFunctionRewardManager", "FunctionRewardManager", "RewardConfig", "SequentialFunctionRewardManager"]
@@ -22,7 +22,7 @@ from typing import Optional
 
 @dataclass
 class RewardConfig:
-    reward_type: str = "function"
+    reward_type: str = "batch"
     reward_function: Optional[str] = None
     reward_function_kwargs: dict = field(default_factory=dict)
     skip_special_tokens: bool = True
......
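A hedged sketch of a config built with the new "batch" default; the reward-function path and kwargs below are placeholders, not values taken from this commit.

# Hedged sketch: RewardConfig with the new "batch" default.
# Path and kwargs are placeholders; reward_function_kwargs are forwarded to the
# user reward function via functools.partial inside FunctionRewardManager.
config = RewardConfig(
    reward_type="batch",                            # was "function" before this commit
    reward_function="/path/to/my_reward.py",        # placeholder script path
    reward_function_kwargs={"format_weight": 0.1},  # placeholder kwargs
    skip_special_tokens=True,
)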
@@ -15,6 +15,7 @@
 import importlib.util
 import os
 import sys
+from abc import ABC, abstractmethod
 from collections import defaultdict
 from functools import partial
 from typing import Callable, Dict, List, Optional, Tuple, TypedDict
@@ -32,10 +33,12 @@ class RewardScore(TypedDict):
     accuracy: Optional[float]
 
-RewardFunction = Callable[[str, str], RewardScore]
+SequentialRewardFunction = Callable[[str, str], RewardScore]
+BatchRewardFunction = Callable[[List[str], List[str]], List[RewardScore]]
 
-class FunctionRewardManager:
+class FunctionRewardManager(ABC):
     """Reward manager for rule-based reward."""
 
     def __init__(self, config: RewardConfig, tokenizer: PreTrainedTokenizer):
@@ -56,29 +59,60 @@ class FunctionRewardManager:
         if not hasattr(module, config.reward_function_name):
             raise AttributeError(f"Module {module} does not have function {config.reward_function_name}.")
 
-        reward_fn: RewardFunction = getattr(module, config.reward_function_name)
+        reward_fn = getattr(module, config.reward_function_name)
         print(f"Using reward function `{config.reward_function_name}` from `{config.reward_function}`.")
         self.reward_fn = partial(reward_fn, **config.reward_function_kwargs)
         self.config = config
         self.tokenizer = tokenizer
 
+    @abstractmethod
     def compute_reward(self, data: DataProto) -> Tuple[torch.Tensor, Dict[str, List[float]]]:
+        """Compute reward for a batch of data."""
+        ...
+
+
+class SequentialFunctionRewardManager(FunctionRewardManager):
+    reward_fn: SequentialRewardFunction
+
+    def compute_reward(self, data: DataProto) -> Tuple[torch.Tensor, Dict[str, List[float]]]:
         reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32)
         reward_metrics = defaultdict(list)
+        response_ids = data.batch["responses"]
+        response_length = data.batch["response_mask"].sum(dim=-1)
         for i in range(len(data)):
-            data_item = data[i]  # DataProtoItem
-            response_ids = data_item.batch["responses"]
-            response_mask = data_item.batch["response_mask"]
-            valid_response_length = response_mask.sum()
-            valid_response_ids = response_ids[:valid_response_length]
+            valid_response_ids = response_ids[i][: response_length[i]]
             response_str = self.tokenizer.decode(
                 valid_response_ids, skip_special_tokens=self.config.skip_special_tokens
             )
-            ground_truth = data_item.non_tensor_batch["ground_truth"]
+            ground_truth = data.non_tensor_batch["ground_truth"][i]
 
             score = self.reward_fn(response_str, ground_truth)
-            reward_tensor[i, valid_response_length - 1] = score["overall"]
+            reward_tensor[i, response_length[i] - 1] = score["overall"]
             for key, value in score.items():
                 reward_metrics[key].append(value)
 
         return reward_tensor, reward_metrics
+
+
+class BatchFunctionRewardManager(FunctionRewardManager):
+    reward_fn: BatchRewardFunction
+
+    def compute_reward(self, data: DataProto) -> Tuple[torch.Tensor, Dict[str, List[float]]]:
+        response_str, ground_truth = [], []
+        response_ids = data.batch["responses"]
+        response_length = data.batch["response_mask"].sum(dim=-1)
+        for i in range(len(data)):
+            valid_response_ids = response_ids[i][: response_length[i]]
+            response_str.append(
+                self.tokenizer.decode(valid_response_ids, skip_special_tokens=self.config.skip_special_tokens)
+            )
+            ground_truth.append(data.non_tensor_batch["ground_truth"][i])
+
+        scores = self.reward_fn(response_str, ground_truth)
+        reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32)
+        reward_metrics = defaultdict(list)
+        for i, score in enumerate(scores):
+            reward_tensor[i, response_length[i] - 1] = score["overall"]
+            for key, value in score.items():
+                reward_metrics[key].append(value)
......
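To make the two calling conventions concrete, a hedged sketch of user-side reward functions matching the SequentialRewardFunction and BatchRewardFunction aliases above; the scoring logic is a placeholder, only the signatures and the returned keys follow the diff.

# Hedged sketch: reward functions matching the two type aliases in the diff.
# Scoring logic is a placeholder; RewardScore carries "overall" and "accuracy" keys.
from typing import Dict, List


def sequential_compute_score(response: str, ground_truth: str) -> Dict[str, float]:
    """SequentialRewardFunction: called once per sample by SequentialFunctionRewardManager."""
    accuracy = 1.0 if ground_truth.strip() in response else 0.0
    return {"overall": accuracy, "accuracy": accuracy}


def batch_compute_score(responses: List[str], ground_truths: List[str]) -> List[Dict[str, float]]:
    """BatchRewardFunction: called once per rollout batch by BatchFunctionRewardManager,
    which makes batched scoring (e.g. one verifier or reward-model call) possible."""
    return [sequential_compute_score(r, gt) for r, gt in zip(responses, ground_truths)]

With reward_type="batch" (the new default), the config's reward_function would presumably point at a script defining the batch-style variant.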
@@ -84,7 +84,8 @@ class vLLMRollout(BaseRollout):
             limit_mm_per_prompt={"image": config.limit_images} if config.limit_images > 0 else None,
             disable_mm_preprocessor_cache=True,
             enable_chunked_prefill=config.enable_chunked_prefill,
-            enable_sleep_mode=False,
+            enable_sleep_mode=False,  # sleep mode is only supported on GPUs
+            # swap_space=20,
         )
 
         # Offload vllm model to reduce peak memory usage
......
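For context, a hedged sketch of what enabling vLLM sleep mode would look like with the offline LLM API; it is left disabled above because sleep mode requires CUDA GPUs. Model name and values are illustrative placeholders.

# Hedged sketch: vLLM sleep mode, kept disabled in this commit (CUDA GPUs only).
# Model name and values are illustrative placeholders.
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen3-8B",      # placeholder model
    enable_sleep_mode=True,     # allow releasing GPU memory between rollout steps
    # swap_space=20,            # GiB of CPU swap per GPU, as hinted in the diff
)
llm.sleep(level=1)  # offload weights to CPU RAM and drop the KV cache
llm.wake_up()       # reload weights before the next generation call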