Commit 20247eb8 authored by chenych

Update 0506

parent 6065b946
@@ -21,7 +21,7 @@ if TYPE_CHECKING:
     from transformers.models.llama.configuration_llama import LlamaConfig
 
-VALID_MODLE_TYPE = {"llama", "qwen2", "qwen2_vl", "qwen2_5_vl"}
+VALID_MODLE_TYPE = {"llama", "qwen2", "qwen2_vl", "qwen2_5_vl", "qwen3"}
 
 def get_device_flops(unit: str = "T") -> float:
...
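For illustration only (not in the diff): a minimal sketch of how a whitelist like VALID_MODLE_TYPE is typically consumed, so the effect of adding "qwen3" is concrete. The guard below is a sketch; the surrounding code of this file is not shown in the commit.

    VALID_MODLE_TYPE = {"llama", "qwen2", "qwen2_vl", "qwen2_5_vl", "qwen3"}  # value after this commit

    def check_model_type(model_type: str) -> None:
        # Illustrative guard: "qwen3" passes after this commit, but not before it.
        if model_type not in VALID_MODLE_TYPE:
            print(f"Model type {model_type} is not in the supported list.")

    check_model_type("qwen3")  # no warning with the updated set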
@@ -18,9 +18,11 @@ from typing import Optional
 from transformers import AutoProcessor, AutoTokenizer, PreTrainedTokenizer, ProcessorMixin
 
-def get_tokenizer(model_path: str, **kwargs) -> PreTrainedTokenizer:
+def get_tokenizer(model_path: str, override_chat_template: Optional[str] = None, **kwargs) -> PreTrainedTokenizer:
     """Create a huggingface pretrained tokenizer."""
     tokenizer = AutoTokenizer.from_pretrained(model_path, **kwargs)
+    if override_chat_template is not None:
+        tokenizer.chat_template = override_chat_template
     if tokenizer.bos_token == "<bos>" and tokenizer.eos_token == "<eos>":
         # the EOS token in gemma2 & gemma3 is ambiguious, which may worsen RL performance.
@@ -35,12 +37,11 @@ def get_tokenizer(model_path: str, **kwargs) -> PreTrainedTokenizer:
     return tokenizer
 
-def get_processor(model_path: str, **kwargs) -> Optional[ProcessorMixin]:
+def get_processor(model_path: str, override_chat_template: Optional[str] = None, **kwargs) -> Optional[ProcessorMixin]:
     """Create a huggingface pretrained processor."""
-    try:
-        processor = AutoProcessor.from_pretrained(model_path, **kwargs)
-    except Exception:
-        processor = None
+    processor = AutoProcessor.from_pretrained(model_path, **kwargs)
+    if override_chat_template is not None:
+        processor.chat_template = override_chat_template
 
     # Avoid load tokenizer, see:
     # https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/auto/processing_auto.py#L344
...
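Not part of the commit: a usage sketch of the new override_chat_template argument. The model name and template string are placeholders, and get_tokenizer is assumed to be imported from the utility module patched above.

    # Placeholder Jinja template; any valid chat template string works here.
    custom_template = (
        "{% for message in messages %}"
        "{{ message['role'] }}: {{ message['content'] }}\n"
        "{% endfor %}"
    )

    # get_tokenizer is the helper patched above; the model path is illustrative.
    tokenizer = get_tokenizer("Qwen/Qwen2.5-7B-Instruct", override_chat_template=custom_template)

    # apply_chat_template now renders with the overridden template instead of the model's default.
    prompt = tokenizer.apply_chat_template([{"role": "user", "content": "Hello"}], tokenize=False)
    print(prompt)  # "user: Hello\n"

The same argument on get_processor covers multimodal models, whose processor carries its own chat template.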
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from .config import RewardConfig
-from .function import FunctionRewardManager
+from .function import BatchFunctionRewardManager, FunctionRewardManager, SequentialFunctionRewardManager
 
-__all__ = ["FunctionRewardManager", "RewardConfig"]
+__all__ = ["BatchFunctionRewardManager", "FunctionRewardManager", "RewardConfig", "SequentialFunctionRewardManager"]
@@ -22,7 +22,7 @@ from typing import Optional
 @dataclass
 class RewardConfig:
-    reward_type: str = "function"
+    reward_type: str = "batch"
     reward_function: Optional[str] = None
     reward_function_kwargs: dict = field(default_factory=dict)
     skip_special_tokens: bool = True
...
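A hedged sketch of how the new default could surface at the call site (not in the diff; the trainer's actual selection logic is outside this commit):

    # Illustrative dispatch from reward_type to the manager classes exported above.
    config = RewardConfig()  # reward_type now defaults to "batch"
    manager_cls = (
        BatchFunctionRewardManager if config.reward_type == "batch" else SequentialFunctionRewardManager
    )
    # Building a manager additionally requires config.reward_function to point at a reward
    # function file, which the constructor loads via importlib.
    reward_manager = manager_cls(config, tokenizer)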
@@ -15,6 +15,7 @@
 import importlib.util
 import os
 import sys
+from abc import ABC, abstractmethod
 from collections import defaultdict
 from functools import partial
 from typing import Callable, Dict, List, Optional, Tuple, TypedDict
@@ -32,10 +33,12 @@ class RewardScore(TypedDict):
     accuracy: Optional[float]
 
-RewardFunction = Callable[[str, str], RewardScore]
+SequentialRewardFunction = Callable[[str, str], RewardScore]
+
+BatchRewardFunction = Callable[[List[str], List[str]], List[RewardScore]]
 
-class FunctionRewardManager:
+class FunctionRewardManager(ABC):
     """Reward manager for rule-based reward."""
 
     def __init__(self, config: RewardConfig, tokenizer: PreTrainedTokenizer):
@@ -56,29 +59,60 @@ class FunctionRewardManager:
         if not hasattr(module, config.reward_function_name):
             raise AttributeError(f"Module {module} does not have function {config.reward_function_name}.")
 
-        reward_fn: RewardFunction = getattr(module, config.reward_function_name)
+        reward_fn = getattr(module, config.reward_function_name)
         print(f"Using reward function `{config.reward_function_name}` from `{config.reward_function}`.")
         self.reward_fn = partial(reward_fn, **config.reward_function_kwargs)
         self.config = config
         self.tokenizer = tokenizer
 
+    @abstractmethod
+    def compute_reward(self, data: DataProto) -> Tuple[torch.Tensor, Dict[str, List[float]]]:
+        """Compute reward for a batch of data."""
+        ...
+
+
+class SequentialFunctionRewardManager(FunctionRewardManager):
+    reward_fn: SequentialRewardFunction
+
     def compute_reward(self, data: DataProto) -> Tuple[torch.Tensor, Dict[str, List[float]]]:
         reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32)
         reward_metrics = defaultdict(list)
+        response_ids = data.batch["responses"]
+        response_length = data.batch["response_mask"].sum(dim=-1)
         for i in range(len(data)):
-            data_item = data[i]  # DataProtoItem
-            response_ids = data_item.batch["responses"]
-            response_mask = data_item.batch["response_mask"]
-            valid_response_length = response_mask.sum()
-            valid_response_ids = response_ids[:valid_response_length]
+            valid_response_ids = response_ids[i][: response_length[i]]
             response_str = self.tokenizer.decode(
                 valid_response_ids, skip_special_tokens=self.config.skip_special_tokens
             )
-            ground_truth = data_item.non_tensor_batch["ground_truth"]
+            ground_truth = data.non_tensor_batch["ground_truth"][i]
 
             score = self.reward_fn(response_str, ground_truth)
-            reward_tensor[i, valid_response_length - 1] = score["overall"]
+            reward_tensor[i, response_length[i] - 1] = score["overall"]
+            for key, value in score.items():
+                reward_metrics[key].append(value)
+
+        return reward_tensor, reward_metrics
+
+
+class BatchFunctionRewardManager(FunctionRewardManager):
+    reward_fn: BatchRewardFunction
+
+    def compute_reward(self, data: DataProto) -> Tuple[torch.Tensor, Dict[str, List[float]]]:
+        response_str, ground_truth = [], []
+        response_ids = data.batch["responses"]
+        response_length = data.batch["response_mask"].sum(dim=-1)
+        for i in range(len(data)):
+            valid_response_ids = response_ids[i][: response_length[i]]
+            response_str.append(
+                self.tokenizer.decode(valid_response_ids, skip_special_tokens=self.config.skip_special_tokens)
+            )
+            ground_truth.append(data.non_tensor_batch["ground_truth"][i])
+
+        scores = self.reward_fn(response_str, ground_truth)
+        reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32)
+        reward_metrics = defaultdict(list)
+        for i, score in enumerate(scores):
+            reward_tensor[i, response_length[i] - 1] = score["overall"]
             for key, value in score.items():
                 reward_metrics[key].append(value)
 
         return reward_tensor, reward_metrics
...
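Outside the diff: a toy reward function matching the new BatchRewardFunction signature, which is what BatchFunctionRewardManager calls with all decoded responses and ground truths at once. The scoring rule is illustrative only.

    from typing import List

    def exact_match_reward(response_strs: List[str], ground_truths: List[str]) -> List[RewardScore]:
        """Toy batch reward: 1.0 for an exact string match, 0.0 otherwise."""
        scores = []
        for response, truth in zip(response_strs, ground_truths):
            accuracy = float(response.strip() == truth.strip())
            # "overall" drives the reward tensor; every key also lands in reward_metrics.
            scores.append({"overall": accuracy, "accuracy": accuracy})
        return scores

A sequential reward function keeps the old per-sample signature (str, str) -> RewardScore and is paired with SequentialFunctionRewardManager; in both cases decoding and batching stay inside the manager.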
@@ -84,7 +84,8 @@ class vLLMRollout(BaseRollout):
             limit_mm_per_prompt={"image": config.limit_images} if config.limit_images > 0 else None,
             disable_mm_preprocessor_cache=True,
             enable_chunked_prefill=config.enable_chunked_prefill,
-            enable_sleep_mode=False,
+            enable_sleep_mode=False,  # only support GPUs
+            # swap_space=20,
         )
 
         # Offload vllm model to reduce peak memory usage
...
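As a hedged sketch outside this diff, the two engine options touched above can be passed directly to recent vLLM versions; the model name and values are placeholders, and swap_space stays commented out in the commit itself.

    from vllm import LLM

    llm = LLM(
        model="Qwen/Qwen2.5-7B-Instruct",  # placeholder model
        enable_sleep_mode=False,  # sleep mode is only supported on GPU backends, hence the note in the diff
        swap_space=20,  # GiB of CPU swap space per GPU; left disabled in the commit
    )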