Commit 7f6cc211 authored by jerrrrry's avatar jerrrrry
Browse files

Initial commit

parents
Pipeline #2874 failed with stages
in 0 seconds
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023-2024 SGLang Team
# Copyright 2025 ModelBest Inc. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import patch
import pytest
from verl.interactions.gsm8k_interaction import Gsm8kInteraction
class TestGsm8kInteraction:
    """Test cases for Gsm8kInteraction class.

    Every scoring path patches ``verl.utils.reward_score.gsm8k.compute_score``,
    so these tests exercise only the interaction's own bookkeeping
    (instance registry, "#### " response normalization, termination flag),
    not the real GSM8K scorer.
    """

    def setup_method(self):
        """Set up test environment before each test method."""
        self.config = {"name": "gsm8k"}
        self.interaction = Gsm8kInteraction(self.config)

    def test_init(self):
        """Test Gsm8kInteraction initialization."""
        assert self.interaction._instance_dict == {}
        assert self.interaction.config == self.config
        assert self.interaction.name == "gsm8k"

    @pytest.mark.asyncio
    async def test_start_interaction_with_instance_id(self):
        """Test start_interaction with provided instance_id."""
        instance_id = "test_instance"
        ground_truth = "42"
        result_id = await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth)
        assert result_id == instance_id
        assert instance_id in self.interaction._instance_dict
        # A fresh instance starts with an empty response and zero reward.
        assert self.interaction._instance_dict[instance_id]["response"] == ""
        assert self.interaction._instance_dict[instance_id]["ground_truth"] == ground_truth
        assert self.interaction._instance_dict[instance_id]["reward"] == 0.0

    @pytest.mark.asyncio
    async def test_start_interaction_without_instance_id(self):
        """Test start_interaction without provided instance_id (auto-generated)."""
        ground_truth = "42"
        result_id = await self.interaction.start_interaction(ground_truth=ground_truth)
        assert result_id is not None
        assert len(result_id) == 36  # UUID4 length
        assert result_id in self.interaction._instance_dict
        assert self.interaction._instance_dict[result_id]["ground_truth"] == ground_truth

    @pytest.mark.asyncio
    async def test_start_interaction_without_ground_truth(self):
        """Test start_interaction without ground_truth parameter."""
        instance_id = "test_instance"
        result_id = await self.interaction.start_interaction(instance_id=instance_id)
        assert result_id == instance_id
        assert self.interaction._instance_dict[instance_id]["ground_truth"] is None

    @pytest.mark.asyncio
    async def test_generate_response_correct_answer_with_prefix(self):
        """Test generate_response with correct answer already having #### prefix."""
        instance_id = "test_instance"
        ground_truth = "42"
        # Setup instance
        await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth)
        messages = [{"role": "user", "content": "#### 42"}]
        with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=1.0):
            should_terminate, response, reward, metadata = await self.interaction.generate_response(
                instance_id, messages
            )
        assert should_terminate is True
        assert response == "Your response is correct!"
        assert reward == 1.0
        assert metadata == {}
        assert self.interaction._instance_dict[instance_id]["response"] == "#### 42"

    @pytest.mark.asyncio
    async def test_generate_response_correct_answer_without_prefix(self):
        """Test generate_response with correct answer missing #### prefix."""
        instance_id = "test_instance"
        ground_truth = "42"
        # Setup instance
        await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth)
        messages = [{"role": "user", "content": "42"}]
        with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=1.0):
            should_terminate, response, reward, metadata = await self.interaction.generate_response(
                instance_id, messages
            )
        assert should_terminate is True
        assert response == "Your response is correct!"
        assert reward == 1.0
        # The interaction prepends "#### " to the stored response when missing.
        assert self.interaction._instance_dict[instance_id]["response"] == "#### 42"

    @pytest.mark.asyncio
    async def test_generate_response_incorrect_answer(self):
        """Test generate_response with incorrect answer."""
        instance_id = "test_instance"
        ground_truth = "42"
        # Setup instance
        await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth)
        messages = [{"role": "user", "content": "24"}]
        with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=0.0):
            should_terminate, response, reward, metadata = await self.interaction.generate_response(
                instance_id, messages
            )
        # Incorrect answers do not terminate: the model gets another attempt.
        assert should_terminate is False
        assert response == "Your response is incorrect! You need to reflect on your answer and try again."
        assert reward == 0.0
        assert self.interaction._instance_dict[instance_id]["response"] == "#### 24"

    @pytest.mark.asyncio
    async def test_generate_response_multiple_messages(self):
        """Test generate_response with multiple messages (should use last user message)."""
        instance_id = "test_instance"
        ground_truth = "42"
        # Setup instance
        await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth)
        messages = [
            {"role": "user", "content": "What is 2+2?"},
            {"role": "assistant", "content": "Let me think about this..."},
            {"role": "user", "content": "#### 42"},
        ]
        with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=1.0):
            should_terminate, response, reward, metadata = await self.interaction.generate_response(
                instance_id, messages
            )
        assert should_terminate is True
        assert response == "Your response is correct!"
        # Only the last user message ("#### 42") is scored and stored.
        assert self.interaction._instance_dict[instance_id]["response"] == "#### 42"

    @pytest.mark.asyncio
    async def test_generate_response_no_user_message(self):
        """Test generate_response with no user messages."""
        instance_id = "test_instance"
        ground_truth = "42"
        # Setup instance
        await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth)
        messages = [{"role": "assistant", "content": "Hello!"}]
        with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=0.0):
            should_terminate, response, reward, metadata = await self.interaction.generate_response(
                instance_id, messages
            )
        assert should_terminate is False
        # With no user message, only the bare "#### " prefix is stored.
        assert self.interaction._instance_dict[instance_id]["response"] == "#### "

    @pytest.mark.asyncio
    async def test_calculate_score_direct_call(self):
        """Test calculate_score method directly."""
        instance_id = "test_instance"
        ground_truth = "42"
        # Setup instance
        await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth)
        # Set a response
        self.interaction._instance_dict[instance_id]["response"] = "#### 42"
        with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=1.0) as mock_compute:
            score = await self.interaction.calculate_score(instance_id)
        assert score == 1.0
        # compute_score is invoked with the stored response and ground truth.
        mock_compute.assert_called_once_with("#### 42", "42", method="flexible", format_score=0.0, score=1.0)

    @pytest.mark.asyncio
    async def test_calculate_score_with_kwargs(self):
        """Test calculate_score method with additional kwargs."""
        instance_id = "test_instance"
        ground_truth = "42"
        # Setup instance
        await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth)
        # Set a response
        self.interaction._instance_dict[instance_id]["response"] = "#### 24"
        with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=0.0) as mock_compute:
            score = await self.interaction.calculate_score(instance_id, extra_param="test")
        assert score == 0.0
        # Extra kwargs are accepted but not forwarded to compute_score.
        mock_compute.assert_called_once_with("#### 24", "42", method="flexible", format_score=0.0, score=1.0)

    @pytest.mark.asyncio
    async def test_finalize_interaction(self):
        """Test finalize_interaction method."""
        instance_id = "test_instance"
        ground_truth = "42"
        # Setup instance
        await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth)
        assert instance_id in self.interaction._instance_dict
        await self.interaction.finalize_interaction(instance_id)
        # Finalizing removes the instance from the registry.
        assert instance_id not in self.interaction._instance_dict

    @pytest.mark.asyncio
    async def test_finalize_interaction_with_kwargs(self):
        """Test finalize_interaction method with additional kwargs."""
        instance_id = "test_instance"
        ground_truth = "42"
        # Setup instance
        await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth)
        assert instance_id in self.interaction._instance_dict
        await self.interaction.finalize_interaction(instance_id, extra_param="test")
        assert instance_id not in self.interaction._instance_dict

    @pytest.mark.asyncio
    async def test_finalize_nonexistent_interaction(self):
        """Test finalize_interaction with non-existent instance_id."""
        instance_id = "nonexistent_instance"
        # This should raise KeyError
        with pytest.raises(KeyError):
            await self.interaction.finalize_interaction(instance_id)

    @pytest.mark.asyncio
    async def test_full_interaction_workflow_correct(self):
        """Test complete interaction workflow with correct answer."""
        ground_truth = "42"
        # Start interaction
        instance_id = await self.interaction.start_interaction(ground_truth=ground_truth)
        # Generate response with correct answer
        messages = [{"role": "user", "content": "42"}]
        with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=1.0):
            should_terminate, response, reward, metadata = await self.interaction.generate_response(
                instance_id, messages
            )
        assert should_terminate is True
        assert reward == 1.0
        # Finalize interaction
        await self.interaction.finalize_interaction(instance_id)
        assert instance_id not in self.interaction._instance_dict

    @pytest.mark.asyncio
    async def test_full_interaction_workflow_incorrect(self):
        """Test complete interaction workflow with incorrect answer."""
        ground_truth = "42"
        # Start interaction
        instance_id = await self.interaction.start_interaction(ground_truth=ground_truth)
        # Generate response with incorrect answer
        messages = [{"role": "user", "content": "24"}]
        with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=0.0):
            should_terminate, response, reward, metadata = await self.interaction.generate_response(
                instance_id, messages
            )
        assert should_terminate is False
        assert reward == 0.0
        # Continue with another attempt
        messages.append({"role": "assistant", "content": response})
        messages.append({"role": "user", "content": "42"})
        with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=1.0):
            should_terminate, response, reward, metadata = await self.interaction.generate_response(
                instance_id, messages
            )
        assert should_terminate is True
        assert reward == 1.0
        # Finalize interaction
        await self.interaction.finalize_interaction(instance_id)
        assert instance_id not in self.interaction._instance_dict

    @pytest.mark.asyncio
    async def test_multiple_concurrent_interactions(self):
        """Test multiple concurrent interaction instances."""
        ground_truth_1 = "42"
        ground_truth_2 = "24"
        # Start multiple interactions
        instance_id_1 = await self.interaction.start_interaction(ground_truth=ground_truth_1)
        instance_id_2 = await self.interaction.start_interaction(ground_truth=ground_truth_2)
        assert len(self.interaction._instance_dict) == 2
        assert instance_id_1 in self.interaction._instance_dict
        assert instance_id_2 in self.interaction._instance_dict
        # Test responses for both instances
        messages_1 = [{"role": "user", "content": "42"}]
        messages_2 = [{"role": "user", "content": "24"}]
        # side_effect feeds one score per generate_response call, in order.
        with patch("verl.utils.reward_score.gsm8k.compute_score", side_effect=[1.0, 1.0]):
            should_terminate_1, _, reward_1, _ = await self.interaction.generate_response(instance_id_1, messages_1)
            should_terminate_2, _, reward_2, _ = await self.interaction.generate_response(instance_id_2, messages_2)
        assert should_terminate_1 is True
        assert should_terminate_2 is True
        assert reward_1 == 1.0
        assert reward_2 == 1.0
        # Finalize both interactions
        await self.interaction.finalize_interaction(instance_id_1)
        await self.interaction.finalize_interaction(instance_id_2)
        assert len(self.interaction._instance_dict) == 0

    @pytest.mark.asyncio
    async def test_edge_case_empty_messages(self):
        """Test edge case with empty messages list."""
        instance_id = "test_instance"
        ground_truth = "42"
        # Setup instance
        await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth)
        messages = []
        with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=0.0):
            should_terminate, response, reward, metadata = await self.interaction.generate_response(
                instance_id, messages
            )
        assert should_terminate is False
        assert reward == 0.0
        assert self.interaction._instance_dict[instance_id]["response"] == "#### "

    @pytest.mark.asyncio
    async def test_edge_case_message_without_content(self):
        """Test edge case with message without content field."""
        instance_id = "test_instance"
        ground_truth = "42"
        # Setup instance
        await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth)
        messages = [
            {"role": "user"}  # Missing content field
        ]
        with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=0.0):
            should_terminate, response, reward, metadata = await self.interaction.generate_response(
                instance_id, messages
            )
        assert should_terminate is False
        assert reward == 0.0
        # Missing content is stringified as "None" after the "#### " prefix.
        assert self.interaction._instance_dict[instance_id]["response"] == "#### None"

    def test_inheritance_from_base_interaction(self):
        """Test that Gsm8kInteraction properly inherits from BaseInteraction."""
        from verl.interactions.base import BaseInteraction

        assert isinstance(self.interaction, BaseInteraction)
        # Test that all required methods are implemented
        assert hasattr(self.interaction, "start_interaction")
        assert hasattr(self.interaction, "generate_response")
        assert hasattr(self.interaction, "calculate_score")
        assert hasattr(self.interaction, "finalize_interaction")
        # Test that methods are callable
        assert callable(self.interaction.start_interaction)
        assert callable(self.interaction.generate_response)
        assert callable(self.interaction.calculate_score)
        assert callable(self.interaction.finalize_interaction)

    def test_name_attribute_initialization(self):
        """Test name attribute initialization with different configs."""
        # Test with explicit name in config
        config_with_name = {"name": "custom_gsm8k"}
        interaction_with_name = Gsm8kInteraction(config_with_name)
        assert interaction_with_name.name == "custom_gsm8k"
        # Test with default name when not provided in config
        config_without_name = {}
        interaction_without_name = Gsm8kInteraction(config_without_name)
        assert interaction_without_name.name == "interaction_agent"  # Default from BaseInteraction
        # Test that name is accessible as attribute
        assert hasattr(self.interaction, "name")
        assert self.interaction.name == "gsm8k"
# Copyright 2023-2024 SGLang Team
# Copyright 2025 ModelBest Inc. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import tempfile
import pytest
from omegaconf import OmegaConf
from verl.interactions.base import BaseInteraction
from verl.interactions.gsm8k_interaction import Gsm8kInteraction
from verl.interactions.utils.interaction_registry import (
get_interaction_class,
initialize_interactions_from_config,
)
class TestInteractionRegistry:
    """Tests for the interaction registry helpers.

    Covers class lookup by fully-qualified name and construction of
    interaction instances from a YAML config file written via OmegaConf.
    Each test writes its config to a temp file and removes it in ``finally``.
    """

    def test_get_interaction_class(self):
        """Test getting interaction class by name."""
        # Test getting base interaction class
        base_cls = get_interaction_class("verl.interactions.base.BaseInteraction")
        assert base_cls == BaseInteraction
        # Test getting gsm8k interaction class
        gsm8k_cls = get_interaction_class("verl.interactions.gsm8k_interaction.Gsm8kInteraction")
        assert gsm8k_cls == Gsm8kInteraction

    def test_initialize_single_interaction_from_config(self):
        """Test initializing single interaction from config."""
        # Create temporary config file
        config_content = {
            "interaction": [
                {
                    "name": "test_gsm8k",
                    "class_name": "verl.interactions.gsm8k_interaction.Gsm8kInteraction",
                    "config": {},
                }
            ]
        }
        with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
            OmegaConf.save(config_content, f.name)
            temp_config_path = f.name
        try:
            interaction_map = initialize_interactions_from_config(temp_config_path)
            # Check that interaction was created
            assert len(interaction_map) == 1
            assert "test_gsm8k" in interaction_map
            assert isinstance(interaction_map["test_gsm8k"], Gsm8kInteraction)
            assert interaction_map["test_gsm8k"].name == "test_gsm8k"
        finally:
            os.unlink(temp_config_path)

    def test_initialize_multiple_interactions_from_config(self):
        """Test initializing multiple interactions from config."""
        config_content = {
            "interaction": [
                {
                    "name": "gsm8k_solver",
                    "class_name": "verl.interactions.gsm8k_interaction.Gsm8kInteraction",
                    "config": {},
                },
                {
                    "name": "base_agent",
                    "class_name": "verl.interactions.base.BaseInteraction",
                    "config": {"custom_param": "test_value"},
                },
            ]
        }
        with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
            OmegaConf.save(config_content, f.name)
            temp_config_path = f.name
        try:
            interaction_map = initialize_interactions_from_config(temp_config_path)
            # Check that both interactions were created
            assert len(interaction_map) == 2
            assert "gsm8k_solver" in interaction_map
            assert "base_agent" in interaction_map
            # Check types
            assert isinstance(interaction_map["gsm8k_solver"], Gsm8kInteraction)
            assert isinstance(interaction_map["base_agent"], BaseInteraction)
            # Check names were injected
            assert interaction_map["gsm8k_solver"].name == "gsm8k_solver"
            assert interaction_map["base_agent"].name == "base_agent"
            # Check custom config was passed
            assert interaction_map["base_agent"].config.get("custom_param") == "test_value"
        finally:
            os.unlink(temp_config_path)

    def test_initialize_interaction_without_explicit_name(self):
        """Test that interaction name is derived from class name when not specified."""
        config_content = {
            "interaction": [{"class_name": "verl.interactions.gsm8k_interaction.Gsm8kInteraction", "config": {}}]
        }
        with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
            OmegaConf.save(config_content, f.name)
            temp_config_path = f.name
        try:
            interaction_map = initialize_interactions_from_config(temp_config_path)
            # Check that interaction name was derived from class name
            assert len(interaction_map) == 1
            assert "gsm8k" in interaction_map  # Should be "gsm8k" after removing "interaction" suffix
            assert isinstance(interaction_map["gsm8k"], Gsm8kInteraction)
            assert interaction_map["gsm8k"].name == "gsm8k"
        finally:
            os.unlink(temp_config_path)

    def test_initialize_empty_config(self):
        """Test initializing from empty config."""
        config_content = {"interaction": []}
        with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
            OmegaConf.save(config_content, f.name)
            temp_config_path = f.name
        try:
            interaction_map = initialize_interactions_from_config(temp_config_path)
            assert len(interaction_map) == 0
        finally:
            os.unlink(temp_config_path)

    def test_invalid_class_name(self):
        """Test handling of invalid class name."""
        config_content = {
            "interaction": [{"name": "invalid", "class_name": "invalid.module.InvalidClass", "config": {}}]
        }
        with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
            OmegaConf.save(config_content, f.name)
            temp_config_path = f.name
        try:
            # Importing a nonexistent module must surface as ModuleNotFoundError.
            with pytest.raises(ModuleNotFoundError):
                initialize_interactions_from_config(temp_config_path)
        finally:
            os.unlink(temp_config_path)

    def test_duplicate_interaction_names(self):
        """Test handling of duplicate interaction names."""
        config_content = {
            "interaction": [
                {"name": "duplicate", "class_name": "verl.interactions.base.BaseInteraction", "config": {}},
                {
                    "name": "duplicate",
                    "class_name": "verl.interactions.gsm8k_interaction.Gsm8kInteraction",
                    "config": {},
                },
            ]
        }
        with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
            OmegaConf.save(config_content, f.name)
            temp_config_path = f.name
        try:
            with pytest.raises(ValueError, match="Duplicate interaction name 'duplicate' found"):
                initialize_interactions_from_config(temp_config_path)
        finally:
            os.unlink(temp_config_path)

    def test_auto_name_generation_edge_cases(self):
        """Test automatic name generation for various class name patterns."""
        config_content = {
            "interaction": [
                {"class_name": "verl.interactions.base.BaseInteraction", "config": {}},
                {"class_name": "verl.interactions.gsm8k_interaction.Gsm8kInteraction", "config": {}},
            ]
        }
        with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
            OmegaConf.save(config_content, f.name)
            temp_config_path = f.name
        try:
            interaction_map = initialize_interactions_from_config(temp_config_path)
            # Check that names were generated correctly
            assert len(interaction_map) == 2
            assert "base" in interaction_map  # BaseInteraction -> base
            assert "gsm8k" in interaction_map  # Gsm8kInteraction -> gsm8k
        finally:
            os.unlink(temp_config_path)
#!/bin/bash
# Cancel all queued GitHub Actions workflow runs for the $OWNER/$REPO repository.
#
# Usage: ./cancel_queued_runs.sh YOUR_GITHUB_TOKEN
if [ "$#" -ne 1 ]; then
    echo "Usage: $0 YOUR_GITHUB_TOKEN"
    echo "Please provide exactly one input argument for your github token."
    exit 1
fi
# Set your GitHub repository details
OWNER="volcengine"
REPO="verl"
TOKEN=$1
# API URL for workflow runs
API_URL="https://api.github.com/repos/$OWNER/$REPO/actions/runs?status=queued"
# Check required commands
command -v jq >/dev/null 2>&1 || { echo "jq is required but not installed. Aborting."; exit 1; }
# Get queued workflow runs
response=$(curl -s -H "Authorization: token $TOKEN" -H "Accept: application/vnd.github.v3+json" "$API_URL")
if [ $? -ne 0 ] || [ -z "$response" ]; then
    echo "Failed to query the GitHub API. Check your network connection."
    exit 1
fi
# Run this for debugging
# echo $response
# Surface API errors (bad token, rate limit, ...) instead of silently
# reporting "no queued runs": error payloads carry a top-level .message.
error_message=$(echo "$response" | jq -r '.message // empty')
if [ -n "$error_message" ]; then
    echo "GitHub API error: $error_message"
    exit 1
fi
# Extract run IDs; `[]?` keeps jq from aborting if .workflow_runs is absent
queued_run_ids=$(echo "$response" | jq -r '.workflow_runs[]? | .id')
if [ -z "$queued_run_ids" ]; then
    echo "No queued workflow runs found."
    exit 0
fi
# Cancel each queued run
for run_id in $queued_run_ids; do
    echo "Cancelling run $run_id"
    cancel_url="https://api.github.com/repos/$OWNER/$REPO/actions/runs/$run_id/cancel"
    curl -s -X POST -H "Authorization: token $TOKEN" -H "Accept: application/vnd.github.v3+json" "$cancel_url"
done
echo "Cancelled all queued workflow runs."
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input
from transformers import (
AutoModelForCausalLM,
AutoModelForTokenClassification,
GemmaConfig,
LlamaConfig,
MistralConfig,
Qwen2Config,
)
from verl.utils.model import compute_position_id_with_mask, create_random_mask
from verl.utils.torch_functional import log_probs_from_logits_all_rmpad, masked_mean
# TODO(sgm): add more models for test
# we only need one scale for each model
# Single-hidden-layer configs keep the rmpad-vs-padded equivalence checks
# cheap while still covering each architecture's attention variant.
test_configs = [
    LlamaConfig(num_hidden_layers=1),
    MistralConfig(num_hidden_layers=1),
    GemmaConfig(num_hidden_layers=1),
    Qwen2Config(num_hidden_layers=1),
]
def test_hf_casual_models():
    """Check that log-probs from the packed (rmpad / flash-attn varlen) forward
    pass match those from the ordinary padded forward pass, for each config in
    ``test_configs``. Requires a CUDA device and flash-attn.
    """
    batch_size = 4
    seqlen = 128
    response_length = 127
    for config in test_configs:
        # config = AutoConfig.from_pretrained(test_case)
        with torch.device("cuda"):
            model = AutoModelForCausalLM.from_config(
                config=config, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
            )
            model = model.to(device="cuda")
        input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seqlen), device="cuda")
        # Random left-padding mask so the padded/rmpad comparison covers
        # sequences of varying valid length.
        attention_mask = create_random_mask(
            input_ids=input_ids,
            max_ratio_of_left_padding=0.1,
            max_ratio_of_valid_token=0.8,
            min_ratio_of_valid_token=0.5,
        )
        position_ids = compute_position_id_with_mask(
            attention_mask
        )  # TODO(sgm): we can construct the position_ids_rmpad here
        input_ids_rmpad, indices, *_ = unpad_input(
            input_ids.unsqueeze(-1), attention_mask
        )  # input_ids_rmpad (total_nnz, ...)
        input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz)
        # unpad the position_ids to align the rotary
        position_ids_rmpad = index_first_axis(
            rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices
        ).transpose(0, 1)
        # input with input_ids_rmpad and position_ids to enable flash attention varlen
        logits_rmpad = model(
            input_ids_rmpad, position_ids=position_ids_rmpad, use_cache=False
        ).logits  # (1, total_nnz, vocab_size)
        # Reference: the standard padded forward pass.
        origin_logits = model(
            input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, use_cache=False
        ).logits
        origin_logits_rmpad, origin_logits_indices, *_ = unpad_input(origin_logits, attention_mask)
        logits_rmpad = logits_rmpad.squeeze(0)
        log_probs = log_probs_from_logits_all_rmpad(
            input_ids_rmpad=input_ids_rmpad,
            logits_rmpad=logits_rmpad,
            indices=indices,
            batch_size=batch_size,
            seqlen=seqlen,
            response_length=response_length,
        )  # (batch, seqlen)
        origin_log_probs = log_probs_from_logits_all_rmpad(
            input_ids_rmpad=input_ids_rmpad,
            logits_rmpad=origin_logits_rmpad,
            indices=origin_logits_indices,
            batch_size=batch_size,
            seqlen=seqlen,
            response_length=response_length,
        )  # (batch, seqlen)
        # Compare masked means over the response window; loose atol because
        # the model runs in bfloat16.
        torch.testing.assert_close(
            masked_mean(log_probs, attention_mask[:, -response_length - 1 : -1]),
            masked_mean(origin_log_probs, attention_mask[:, -response_length - 1 : -1]),
            atol=1e-2,
            rtol=1e-5,
        )
    print("Check pass")
def test_hf_value_models():
    """Same rmpad-vs-padded equivalence as ``test_hf_casual_models`` but for
    value (token-classification, num_labels=1) heads: re-padded rmpad logits
    must match the padded forward pass. Requires CUDA and flash-attn.
    """
    batch_size = 4
    seqlen = 128
    for config in test_configs:
        # config = AutoConfig.from_pretrained(test_case)
        # Single scalar output per token, dropout disabled for determinism.
        config.num_labels = 1
        config.classifier_dropout = 0
        config.hidden_dropout = 0
        with torch.device("cuda"):
            model = AutoModelForTokenClassification.from_config(
                config=config, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
            )
            model = model.to(device="cuda")
        input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seqlen), device="cuda")
        attention_mask = create_random_mask(
            input_ids=input_ids,
            max_ratio_of_left_padding=0.1,
            max_ratio_of_valid_token=0.8,
            min_ratio_of_valid_token=0.5,
        )
        position_ids = compute_position_id_with_mask(
            attention_mask
        )  # TODO(sgm): we can construct the position_ids_rmpad here
        input_ids_rmpad, indices, *_ = unpad_input(
            input_ids.unsqueeze(-1), attention_mask
        )  # input_ids_rmpad (total_nnz, ...)
        input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz)
        # unpad the position_ids to align the rotary
        position_ids_rmpad = index_first_axis(
            rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices
        ).transpose(0, 1)
        origin_logits = model(
            input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, use_cache=False
        ).logits
        # input with input_ids_rmpad and position_ids to enable flash attention varlen
        rmpad_logits = model(
            input_ids_rmpad, position_ids=position_ids_rmpad, use_cache=False
        ).logits  # (1, total_nnz, 1)
        rmpad_logits = rmpad_logits.squeeze(0)
        # Scatter the packed logits back to (batch, seqlen, 1) for comparison.
        pad_logits = pad_input(rmpad_logits, indices, batch_size, seqlen=seqlen)
        torch.testing.assert_close(
            masked_mean(pad_logits, attention_mask[:, :, None]),
            masked_mean(origin_logits, attention_mask[:, :, None]),
            atol=1e-2,
            rtol=1e-5,
        )
    print("Value model check pass")
if __name__ == "__main__":
    # Allow running this test file directly (needs a CUDA device + flash-attn).
    test_hf_casual_models()
    test_hf_value_models()
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import copy
from dataclasses import dataclass
import pytest
import torch
import torch.distributed
from flash_attn.bert_padding import index_first_axis, rearrange, unpad_input
from torch.distributed import init_device_mesh
from transformers import AutoModelForCausalLM, LlamaConfig, PretrainedConfig, Qwen2Config
from verl.models.transformers.monkey_patch import apply_monkey_patch
from verl.protocol import DataProto
from verl.utils.distributed import initialize_global_process_group
from verl.utils.model import compute_position_id_with_mask, create_random_mask
from verl.utils.ulysses import (
gather_outputs_and_unpad,
get_ulysses_sequence_parallel_world_size,
set_ulysses_sequence_parallel_group,
ulysses_pad_and_slice_inputs,
)
from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager
# TODO(sgm): add more models for test
# we only need one scale for each model
@dataclass
class SequenceParallelConfig:
    """One Ulysses sequence-parallel test case.

    Bundles a HF model config with the sequence-parallel size to test and
    whether the combination is expected to be accepted (see is_valid).
    """

    config: PretrainedConfig  # HF model config under test
    sp_size: int  # Ulysses sequence-parallel group size
    is_valid: bool  # False => the combination should raise AssertionError
def test_configs():
    """Enumerate (model config, sp_size, expected-validity) combinations for
    the sequence-parallel forward/backward tests."""
    cases = [
        (LlamaConfig(num_hidden_layers=2, num_attention_heads=32, num_key_value_heads=32), 8, True),
        (Qwen2Config(num_hidden_layers=2, num_attention_heads=28, num_key_value_heads=4, hidden_size=3584), 4, True),
        (Qwen2Config(num_hidden_layers=2, num_attention_heads=28, num_key_value_heads=4, hidden_size=3584), 8, False),
        (Qwen2Config(num_hidden_layers=2, num_attention_heads=32, num_key_value_heads=4), 4, True),
        (Qwen2Config(num_hidden_layers=2, num_attention_heads=32, num_key_value_heads=4), 8, True),
    ]
    return [SequenceParallelConfig(cfg, sp_size=size, is_valid=ok) for cfg, size, ok in cases]
def sync_model_parameters_global(layer):
    """Broadcast every parameter of ``layer`` from rank 0 so that all ranks
    start from identical weights."""
    for param in layer.parameters():
        torch.distributed.broadcast(tensor=param.data, src=0)
@pytest.mark.parametrize("test_config", test_configs())
def test_hf_casual_fwd_bwd(test_config):
    """Run forward+backward under Ulysses sequence parallelism; combinations
    marked invalid are expected to fail with an AssertionError."""
    if not torch.distributed.is_initialized():
        initialize_global_process_group()
    if test_config.is_valid:
        expectation = contextlib.nullcontext()
    else:
        expectation = pytest.raises(AssertionError)
    with expectation:
        dp_size = torch.distributed.get_world_size() // test_config.sp_size
        _hf_casual_fwd_bwd(test_config.config, test_config.sp_size, dp_size)
# TODO: seems not work, will cause `socketStartConnect: Connect to xxx failed : Software caused connection abort`
# torch.distributed.destroy_process_group()
def _hf_casual_fwd(config, sp_size, dp_size):
    """Forward-only parity check: logits from an Ulysses sequence-parallel
    forward must match a plain (non-SP) forward of the same model.

    Args:
        config: HF model config used to instantiate the causal LM.
        sp_size: Ulysses sequence-parallel group size.
        dp_size: data-parallel group size (world_size // sp_size).
    """
    assert torch.cuda.device_count() >= 2, "need at least 2 gpus for test"
    ulysses_device_mesh = init_device_mesh(
        device_type="cuda", mesh_shape=(dp_size, sp_size), mesh_dim_names=("dp", "sp")
    )
    sharding_manager = FSDPUlyssesShardingManager(ulysses_device_mesh)

    batch_size = 1
    seqlen = 128
    # response_length = 127

    # patch before load
    with torch.device("cuda"):
        model = AutoModelForCausalLM.from_config(
            config=config, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
        )
        apply_monkey_patch(model, sp_size)
        model = model.to(device="cuda")
    # Broadcast rank-0 weights so every rank starts from identical parameters.
    sync_model_parameters_global(model)

    # different rank will generate different input_ids following fsdp
    input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seqlen), device="cuda")
    attention_mask = create_random_mask(
        input_ids=input_ids, max_ratio_of_left_padding=0, max_ratio_of_valid_token=0.9, min_ratio_of_valid_token=0.8
    )
    position_ids = compute_position_id_with_mask(
        attention_mask
    )  # TODO(sgm): we can construct the position_ids_rmpad here

    model_inputs = {
        "input_ids": input_ids.cuda(),
        "attention_mask": attention_mask.cuda(),
        "position_ids": position_ids.int().cuda(),
    }
    model_inputs = DataProto.from_dict(model_inputs)

    # 1. perform ulysses forward
    with sharding_manager:
        model_inputs = sharding_manager.preprocess_data(model_inputs)
        input_ids = model_inputs.batch["input_ids"]
        attention_mask = model_inputs.batch["attention_mask"]
        position_ids = model_inputs.batch["position_ids"]
        # Remove padding tokens so flash-attention varlen can be used.
        input_ids_rmpad, indices, *_ = unpad_input(
            input_ids.unsqueeze(-1), attention_mask
        )  # input_ids_rmpad (total_nnz, ...)
        input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz)

        # unpad the position_ids to align the rotary
        position_ids_rmpad = index_first_axis(
            rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices
        ).transpose(0, 1)

        # slice input tensor for ulysses
        # input_ids are padded and sliced
        # postition_ids are only padded but not sliced
        input_ids_rmpad_sliced, position_ids_rmpad_padded, pad_size = ulysses_pad_and_slice_inputs(
            input_ids_rmpad, position_ids_rmpad, sp_size=get_ulysses_sequence_parallel_world_size()
        )

        # input with input_ids_rmpad and postition_ids to enable flash attention varlen
        logits_split_in_seq = model(
            input_ids_rmpad_sliced, position_ids=position_ids_rmpad_padded, use_cache=False
        ).logits  # (1, total_nnz/n, vocab_size)

        # all_gather output
        logits_full = gather_outputs_and_unpad(logits_split_in_seq, gather_dim=1, unpad_dim=1, padding_size=pad_size)

        # 2. perform normal forward
        # Disable SP so the same model runs a plain forward for comparison.
        set_ulysses_sequence_parallel_group(None)
        logits_rmpad_local = model(
            input_ids_rmpad, position_ids=position_ids_rmpad, use_cache=False
        ).logits  # (1, total_nnz, vocab_size)

        # Compare means rather than full tensors (bf16 noise tolerance).
        mean_local = logits_rmpad_local.mean()
        mean_full = logits_full.mean()
        torch.testing.assert_close(mean_local, mean_full, rtol=1e-2, atol=1e-5)
def _hf_casual_fwd_bwd(config, sp_size, dp_size):
    """Forward+backward parity check for Ulysses sequence parallelism.

    Runs an SP forward/backward and a plain forward/backward on a deep copy of
    the same model, then compares both the logits means and the q_proj weight
    gradients of the first layer.

    Args:
        config: HF model config used to instantiate the causal LM.
        sp_size: Ulysses sequence-parallel group size.
        dp_size: data-parallel group size (world_size // sp_size).
    """
    assert torch.cuda.device_count() >= 2, "need at least 2 gpus for test"
    ulysses_device_mesh = init_device_mesh(
        device_type="cuda", mesh_shape=(dp_size, sp_size), mesh_dim_names=("dp", "sp")
    )
    sharding_manager = FSDPUlyssesShardingManager(ulysses_device_mesh)

    batch_size = 1
    seqlen = 128
    # response_length = 127

    # patch before load
    with torch.device("cuda"):
        model = AutoModelForCausalLM.from_config(
            config=config, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
        )
        apply_monkey_patch(model, sp_size)
        model = model.to(device="cuda")
    # Broadcast rank-0 weights so every rank starts from identical parameters.
    sync_model_parameters_global(model)

    # different rank will generate different input_ids following fsdp
    input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seqlen), device="cuda")
    attention_mask = create_random_mask(
        input_ids=input_ids, max_ratio_of_left_padding=0, max_ratio_of_valid_token=0.9, min_ratio_of_valid_token=0.8
    )
    position_ids = compute_position_id_with_mask(
        attention_mask
    )  # TODO(sgm): we can construct the position_ids_rmpad here

    model_inputs = {
        "input_ids": input_ids.cuda(),
        "attention_mask": attention_mask.cuda(),
        "position_ids": position_ids.int().cuda(),
    }
    model_inputs = DataProto.from_dict(model_inputs)

    # 1. perform ulysses forward
    with sharding_manager:
        model_inputs = sharding_manager.preprocess_data(model_inputs)
        input_ids = model_inputs.batch["input_ids"]
        attention_mask = model_inputs.batch["attention_mask"]
        position_ids = model_inputs.batch["position_ids"]
        input_ids_rmpad, indices, *_ = unpad_input(
            input_ids.unsqueeze(-1), attention_mask
        )  # input_ids_rmpad (total_nnz, ...)
        input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz)

        # unpad the position_ids to align the rotary
        position_ids_rmpad = index_first_axis(
            rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices
        ).transpose(0, 1)

        # slice input tensor for ulysses
        # input_ids are padded and sliced
        # postition_ids are only padded but not sliced
        input_ids_rmpad_sliced, position_ids_rmpad_padded, pad_size = ulysses_pad_and_slice_inputs(
            input_ids_rmpad, position_ids_rmpad, sp_size=get_ulysses_sequence_parallel_world_size()
        )

        # input with input_ids_rmpad and postition_ids to enable flash attention varlen
        logits_split_in_seq = model(
            input_ids_rmpad_sliced, position_ids=position_ids_rmpad_padded, use_cache=False
        ).logits  # (1, total_nnz/n, vocab_size)

        # all_gather output
        logits_full = gather_outputs_and_unpad(logits_split_in_seq, gather_dim=1, unpad_dim=1, padding_size=pad_size)

        # 2. perform normal forward
        set_ulysses_sequence_parallel_group(None)
        # Deep-copy inputs and model so the non-SP pass accumulates gradients
        # into an independent set of parameters.
        input_ids_full = copy.deepcopy(input_ids_rmpad)
        position_ids_full = copy.deepcopy(position_ids_rmpad)
        model_no_sp = copy.deepcopy(model)
        logits_rmpad_local = model_no_sp(
            input_ids_full, position_ids=position_ids_full, use_cache=False
        ).logits  # (1, total_nnz, vocab_size)

        mean_local = logits_rmpad_local.mean()
        mean_full = logits_full.mean()
        mean_full.backward()
        mean_local.backward()

        # 3. check the gradients
        grad = model.model.layers[0].self_attn.q_proj.weight.grad
        grad_full = model_no_sp.model.layers[0].self_attn.q_proj.weight.grad
        torch.testing.assert_close(mean_local, mean_full, rtol=1e-2, atol=1e-5)
        # NOTE(review): atol/rtol pairing below is flipped relative to the mean
        # check above (atol=1e-2 here vs rtol=1e-2 there) - confirm which
        # tolerance pairing is intended for the gradient comparison.
        torch.testing.assert_close(grad, grad_full, atol=1e-2, rtol=1e-5)
if __name__ == "__main__":
    # Allow running this test module directly without the pytest CLI.
    pytest.main([__file__, "-svv"])
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import verl.single_controller.base.decorator as decorator_module
from verl.single_controller.base.decorator import (
DISPATCH_MODE_FN_REGISTRY,
Dispatch,
_check_dispatch_mode,
get_predefined_dispatch_fn,
register_dispatch_mode,
update_dispatch_mode,
)
@pytest.fixture
def reset_dispatch_registry():
    """Snapshot the dispatch-mode registry and restore it after the test."""
    snapshot = DISPATCH_MODE_FN_REGISTRY.copy()
    yield
    # Put the module-level registry back exactly as it was before the test.
    decorator_module.DISPATCH_MODE_FN_REGISTRY.clear()
    decorator_module.DISPATCH_MODE_FN_REGISTRY.update(snapshot)
def test_register_new_dispatch_mode(reset_dispatch_registry):
    """register_dispatch_mode must extend the Dispatch enum and the fn registry."""

    # Test registration
    def dummy_dispatch(worker_group, *args, **kwargs):
        return args, kwargs

    def dummy_collect(worker_group, output):
        return output

    register_dispatch_mode("TEST_MODE", dummy_dispatch, dummy_collect)

    # Verify enum extension
    _check_dispatch_mode(Dispatch.TEST_MODE)

    # Verify registry update
    assert get_predefined_dispatch_fn(Dispatch.TEST_MODE) == {
        "dispatch_fn": dummy_dispatch,
        "collect_fn": dummy_collect,
    }

    # Clean up
    Dispatch.remove("TEST_MODE")
def test_update_existing_dispatch_mode(reset_dispatch_registry):
    """update_dispatch_mode must swap the dispatch/collect fns of an existing mode."""
    # Store original implementation
    original_mode = Dispatch.ONE_TO_ALL

    # New implementations
    def new_dispatch(worker_group, *args, **kwargs):
        return args, kwargs

    def new_collect(worker_group, output):
        return output

    # Test update
    update_dispatch_mode(original_mode, new_dispatch, new_collect)

    # Verify update
    assert get_predefined_dispatch_fn(original_mode)["dispatch_fn"] == new_dispatch
    assert get_predefined_dispatch_fn(original_mode)["collect_fn"] == new_collect
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import time
import ray
from verl.single_controller.base.decorator import Dispatch, register
from verl.single_controller.base.worker import Worker
from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
@ray.remote
class TestActor(Worker):
    """Worker whose only method sleeps, then kills its own process.

    Used to exercise the worker-group aliveness check in the driver script.
    """

    def __init__(self) -> None:
        super().__init__()

    @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
    def foo(self, wait_time):
        time.sleep(wait_time)
        # Exit with a non-zero status so the worker process dies; the driver's
        # aliveness check should observe the failure.
        sys.exit(1)
if __name__ == "__main__":
    # How long each worker sleeps before killing itself (seconds).
    wait_time = int(os.getenv("WAIT_TIME", "10"))

    ray.init()

    # test single-node-no-partition
    print("test single-node-no-partition")
    resource_pool = RayResourcePool([2], use_gpu=False)
    class_with_args = RayClassWithInitArgs(cls=TestActor)

    print("create worker group")
    wg = RayWorkerGroup(resource_pool, class_with_args, name_prefix="test")
    # Poll worker liveness every second; workers will exit inside foo().
    wg.start_worker_aliveness_check(1)
    time.sleep(1)

    print(time.time(), "start foo")
    _ = wg.foo(wait_time)
    print("foo started")
    print(
        time.time(),
        f"wait 6x wait time {wait_time * 6} to let signal returned to process but still not exceed process wait time",
    )
    time.sleep(wait_time * 6)

    ray.shutdown()
# Detached Worker
## How to run (Only on a single node)
- Start a local ray cluster:
```bash
ray start --head --port=6379
```
- Run the server
```bash
python3 server.py
```
- In another terminal, run the client
```bash
python3 client.py
```
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
In client, we can get the server handler and send RPC request
"""
import ray
import torch
from server import Trainer
from tensordict import TensorDict
from verl import DataProto
from verl.single_controller.ray import RayClassWithInitArgs
from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup
def compute_position_id_with_mask(mask):
    """Derive position ids from an attention mask.

    Each position gets the 0-based count of valid tokens up to and including
    itself (cumulative sum minus one), clipped at zero so leading padding
    positions stay 0.
    """
    cumulative = torch.cumsum(mask, dim=-1)
    return torch.clip(cumulative - 1, min=0, max=None)
if __name__ == "__main__":
    # Attach to the already-running cluster started by server.py.
    ray.init(address="auto", namespace="verl")

    # get the worker group using names
    # NOTE(review): these names must match what the server registered its
    # detached workers under (name_prefix="trainer") - verify against server.py.
    worker_names = ["trainerTrainer_0:0", "trainerTrainer_0:1"]
    cls_with_init_args = RayClassWithInitArgs(cls=Trainer)
    worker_group = NVMegatronRayWorkerGroup.from_detached(
        worker_names=worker_names, ray_cls_with_init=cls_with_init_args
    )

    batch_size = 16
    sequence_length = 1024

    # give Trainer some data to train
    input_ids = torch.randint(low=0, high=256, size=(batch_size, sequence_length), dtype=torch.int64, device="cuda")
    attention_mask = torch.ones_like(input_ids)
    position_ids = compute_position_id_with_mask(attention_mask)

    data = DataProto(
        batch=TensorDict(
            {"input_ids": input_ids, "attention_mask": attention_mask, "position_ids": position_ids},
            batch_size=batch_size,
        ),
        meta_info={},
    )

    # RPC into the detached trainer workers and print the returned loss proto.
    output = worker_group.train_model(data)
    print(output)
#!/bin/bash
# End-to-end driver for the detached-worker demo: start a local Ray head,
# launch the detached server, run the client against it, then tear down.
ray start --head --port=6379
python3 server.py
python3 client.py
ray stop --force
\ No newline at end of file
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Server starts a Trainer. Client sends data to the server to train.
"""
import os
os.environ["MEGATRON_USE_CUDA_TIMER"] = "0"
os.environ["MEGATRON_START_PROCESS_TIMER"] = "False"
os.environ["NCCL_DEBUG"] = "WARN"
import ray
import torch
from megatron.core import parallel_state as mpu
from megatron.core import tensor_parallel
from megatron.core.models.gpt.gpt_model import ModelType
from omegaconf import OmegaConf
from tensordict import TensorDict
from torch import nn
from transformers import LlamaConfig
from verl import DataProto
from verl.models.llama.megatron import ParallelLlamaForCausalLMRmPadPP
from verl.single_controller.base.decorator import Dispatch, register
from verl.single_controller.base.megatron.worker import MegatronWorker
from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool
from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup
from verl.utils.megatron.optimizer import get_megatron_optimizer
from verl.utils.megatron_utils import get_model, init_megatron_optim_config, mcore_model_parallel_config
@ray.remote
class Trainer(MegatronWorker):
    """Detached Megatron worker owning a parallel Llama model that trains on
    DataProto batches sent by a remote client."""

    def __init__(self):
        super().__init__()
        # Bootstrap torch.distributed + Megatron model parallelism once per
        # process (tp=2, pp=1, no virtual pp / context / expert parallelism).
        if not torch.distributed.is_initialized():
            rank = int(os.environ["LOCAL_RANK"])
            torch.distributed.init_process_group(backend="nccl")
            torch.cuda.set_device(rank)

            mpu.initialize_model_parallel(
                tensor_model_parallel_size=2,
                pipeline_model_parallel_size=1,
                virtual_pipeline_model_parallel_size=None,
                pipeline_model_parallel_split_rank=None,
                use_sharp=False,
                context_parallel_size=1,
                expert_model_parallel_size=1,
                nccl_communicator_config_path=None,
            )
            tensor_parallel.model_parallel_cuda_manual_seed(10)

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def init_model(self):
        """Build the parallel Llama model (DDP-wrapped) and its Megatron optimizer."""
        actor_model_config = LlamaConfig(
            vocab_size=256,
            hidden_size=2048,
            intermediate_size=5504,
            num_hidden_layers=24,
            num_attention_heads=16,
            num_key_value_heads=16,
        )

        megatron_config = mcore_model_parallel_config(sequence_parallel=True, params_dtype=torch.bfloat16)
        self.megatron_config = megatron_config

        def megatron_actor_model_provider(pre_process, post_process):
            # vpp is not supported yet because it will hang for some reason. Need debugging
            # this_megatron_config = copy.deepcopy(megatron_config)
            # this_megatron_config.virtual_pipeline_model_parallel_rank = vpp_rank
            parallel_model = ParallelLlamaForCausalLMRmPadPP(
                config=actor_model_config,
                megatron_config=megatron_config,
                pre_process=pre_process,
                post_process=post_process,
            )
            parallel_model.cuda()
            return parallel_model

        actor_module = get_model(
            model_provider_func=megatron_actor_model_provider,
            model_type=ModelType.encoder_or_decoder,
            wrap_with_ddp=True,
        )
        actor_module = nn.ModuleList(actor_module)

        optim_config = OmegaConf.create({"lr": 1e-6, "clip_grad": 1.0})
        optim_config = init_megatron_optim_config(optim_config)
        self.optimizer_config = optim_config
        actor_optimizer = get_megatron_optimizer(model=actor_module, config=optim_config)

        # pipeline_model_parallel_size=1, so the list holds one model chunk.
        self.model = actor_module[0]
        self.optimizer = actor_optimizer

    @register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO)
    def train_model(self, data: DataProto) -> DataProto:
        """Run one optimizer step on ``data``; return detached logits as "loss"."""
        input_ids = data.batch["input_ids"]
        attention_mask = data.batch["attention_mask"]
        position_ids = data.batch["position_ids"]
        self.optimizer.zero_grad()
        self.model.zero_grad_buffer(
            zero_buffer=(not self.optimizer_config.use_distributed_optimizer)
        )  # use use_contiguous_buffers_in_local_ddp and no overlap_dp_param_comm
        # update for 1 iteration
        output = self.model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids).logits
        output.mean().backward()
        update_successful, grad_norm, num_zeros_in_grad = self.optimizer.step(
            self.megatron_config, self.megatron_config.timers
        )
        return DataProto(batch=TensorDict({"loss": output.detach()}, batch_size=output.shape[0]))
if __name__ == "__main__":
    ray.init(address="auto", namespace="verl")
    # Detached pool + worker group so the workers outlive this script and the
    # client can later attach to them by name.
    resource_pool = RayResourcePool(process_on_nodes=[2], detached=True)
    cls_with_init_args = RayClassWithInitArgs(cls=Trainer)
    worker_group = NVMegatronRayWorkerGroup(
        resource_pool=resource_pool,
        ray_cls_with_init=cls_with_init_args,
        name_prefix="trainer",
        detached=True,
    )
    worker_group.init_model()
    worker_names = worker_group.worker_names
    # Print the detached worker names; client.py hard-codes these to reconnect.
    print(worker_names)
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import ray
import torch
from verl import DataProto
from verl.protocol import DataProtoConfig
from verl.single_controller.base import Worker
from verl.single_controller.base.decorator import Dispatch, register
from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
# or set env var VERL_AUTO_PADDING = "1" / "true"
# Globally enable DataProto auto padding for this test module.
DataProtoConfig.auto_padding = True
@ray.remote
class Actor(Worker):
    """Worker that adds its own rank to the "a" tensor of the incoming batch."""

    def __init__(self) -> None:
        super().__init__()

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def add(self, data: DataProto):
        data.batch["a"] += self.rank
        return data
def test_auto_padding():
    """End-to-end check of DataProto auto padding with a DP_COMPUTE_PROTO group.

    First validates DataProto.padding()/chunk() locally for sizes that are and
    are not multiples of the chunk size, then verifies that dispatching through
    a RayWorkerGroup pads/unpads transparently for positional args, kwargs, the
    global DataProtoConfig switch, and the per-proto ``auto_padding`` flag.
    """
    ray.init(num_cpus=100)
    chunk_size = 4
    actor_cls = RayClassWithInitArgs(cls=Actor)
    resource_pool = RayResourcePool(process_on_nodes=[chunk_size], use_gpu=False)
    actor_wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=actor_cls)

    # test locally first
    for test_size in range(4, 20):
        local_data = DataProto.from_dict({"a": torch.zeros(test_size)}, {"na": np.zeros(test_size, dtype=object)})
        padding_size = (chunk_size - (test_size % chunk_size)) if (test_size % chunk_size > 0) else 0
        local_data.padding(padding_size)
        # BUGFIX: compare against the pre-padding size. The old assertion used
        # len(local_data) on both sides after padding (when len % chunk_size is
        # always 0), so it could never fail.
        expected_padded_len = test_size + padding_size
        assert len(local_data) == expected_padded_len, (
            f"expecting padded length to be {expected_padded_len}, but got {len(local_data)}"
        )
        chunked = local_data.chunk(chunk_size)
        assert len(chunked) == chunk_size, f"during test_size = {test_size}, expecting {chunk_size}, got {chunked}"
        for dp in chunked:
            assert len(dp) == test_size // chunk_size + bool(test_size % chunk_size), (
                f"test size = {test_size}, expecting dp to be length of "
                f"{test_size // chunk_size + bool(test_size % chunk_size)}, but got {len(dp)}: {dp} {chunked}"
            )

    def _str_arr(size):
        # non-tensor payload: object array of stringified indices
        return np.array([str(i) for i in range(size)], dtype=object)

    def _check_add(data, expected_len, by_kwarg):
        # dispatch through the worker group and verify round-trip length
        output = actor_wg.add(data=data) if by_kwarg else actor_wg.add(data)
        print(output.batch["a"])
        kind = "kwargs" if by_kwarg else "args"
        assert len(output) == expected_len, f"Failed in {kind} split and padding."

    # test with RayWorkerGroup method decorated as dispatch_mode=Dispatch.DP_COMPUTE_PROTO
    # (global auto-padding switch is on; sizes cover non-divisible, single, divisible)
    for size in (10, 1, 8):
        _check_add(DataProto.from_dict({"a": torch.zeros(size)}, {"na": _str_arr(size)}), size, by_kwarg=False)
        _check_add(DataProto.from_dict({"a": torch.zeros(size)}, {"na": _str_arr(size)}), size, by_kwarg=True)

    # test data proto specific config: the per-proto auto_padding flag must
    # still pad even when the global switch is off
    DataProtoConfig.auto_padding = False
    _check_add(
        DataProto.from_dict({"a": torch.zeros(10)}, {"na": _str_arr(10)}, auto_padding=True), 10, by_kwarg=False
    )
    _check_add(
        DataProto.from_dict({"a": torch.zeros(10)}, {"na": _str_arr(10)}, auto_padding=True), 10, by_kwarg=True
    )
    _check_add(
        DataProto.from_single_dict({"a": torch.zeros(1), "na": _str_arr(1)}, auto_padding=True), 1, by_kwarg=False
    )
    _check_add(
        DataProto.from_single_dict({"a": torch.zeros(1), "na": _str_arr(1)}, auto_padding=True), 1, by_kwarg=True
    )
    # size divisible by the dp size: succeeds with no padding flags at all
    _check_add(DataProto.from_single_dict({"a": torch.zeros(8), "na": _str_arr(8)}), 8, by_kwarg=False)
    _check_add(DataProto.from_single_dict({"a": torch.zeros(8), "na": _str_arr(8)}), 8, by_kwarg=True)
    ray.shutdown()
if __name__ == "__main__":
    # Allow running the padding test directly as a script.
    test_auto_padding()
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ray
from verl import DataProto
from verl.single_controller.base import Worker
from verl.single_controller.base.decorator import Dispatch, register
from verl.single_controller.ray.base import (
RayClassWithInitArgs,
RayResourcePool,
RayWorkerGroup,
create_colocated_worker_cls,
)
@ray.remote
class Actor(Worker):
    """Worker that adds its own rank to the "a" tensor of the incoming batch."""

    def __init__(self) -> None:
        super().__init__()

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def add(self, data: DataProto):
        data.batch["a"] += self.rank
        return data
@ray.remote
class Critic(Worker):
    """Worker that subtracts the configured constant "b" from the "a" tensor.

    ``sub`` is an async method here, so dispatch must handle coroutine worker
    methods (presumably intentional - the fused-test variant uses a sync def).
    """

    def __init__(self, config) -> None:
        super().__init__()
        self.config = config

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    async def sub(self, data: DataProto):
        data.batch["a"] -= self.config["b"]
        return data
def test_colocated_workers():
    """Colocated actor+critic worker groups must produce byte-identical outputs
    to separately created worker groups on the same resource pool."""
    ray.init()

    import torch

    data = DataProto.from_dict({"a": torch.zeros(10)})
    # create separate workers on the same resource pool
    actor_cls = RayClassWithInitArgs(cls=Actor)
    critic_cls = RayClassWithInitArgs(cls=Critic, config={"b": 10})
    resource_pool = RayResourcePool(process_on_nodes=[2])

    actor_wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=actor_cls)
    critic_wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=critic_cls)

    # Reference outputs from the non-colocated groups.
    expected_actor_output = actor_wg.add(data)
    expected_critic_output = critic_wg.sub(data)

    # create colocated workers
    cls_dict = {"actor": actor_cls, "critic": critic_cls}
    ray_cls_with_init = create_colocated_worker_cls(cls_dict)
    wg_dict = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init)
    spawn_wg = wg_dict.spawn(prefix_set=cls_dict.keys())

    colocated_actor_wg = spawn_wg["actor"]
    colocated_critic_wg = spawn_wg["critic"]

    actor_output = colocated_actor_wg.add(data)
    critic_output = colocated_critic_wg.sub(data)

    # Exact match expected (atol=0, rtol=0): same ranks, same arithmetic.
    torch.testing.assert_close(expected_actor_output.batch, actor_output.batch, atol=0, rtol=0)
    torch.testing.assert_close(expected_critic_output.batch, critic_output.batch, atol=0, rtol=0)
    ray.shutdown()
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ray
from verl import DataProto
from verl.single_controller.base import Worker
from verl.single_controller.base.decorator import Dispatch, register
from verl.single_controller.ray.base import (
RayClassWithInitArgs,
RayResourcePool,
RayWorkerGroup,
create_colocated_worker_cls_fused,
)
@ray.remote
class Actor(Worker):
    """Worker that adds its own rank to the "a" tensor of the incoming batch."""

    def __init__(self) -> None:
        super().__init__()

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def add(self, data: DataProto):
        data.batch["a"] += self.rank
        return data
@ray.remote
class Critic(Worker):
    """Worker that subtracts the configured constant "b" from the "a" tensor."""

    def __init__(self, config) -> None:
        super().__init__()
        self.config = config

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def sub(self, data: DataProto):
        data.batch["a"] -= self.config["b"]
        return data
def test_colocated_workers_fused():
    """Same parity check as test_colocated_workers, but using the fused
    colocated-worker class factory."""
    ray.init()

    import torch

    data = DataProto.from_dict({"a": torch.zeros(10)})
    # create separate workers on the same resource pool
    actor_cls = RayClassWithInitArgs(cls=Actor)
    critic_cls = RayClassWithInitArgs(cls=Critic, config={"b": 10})
    resource_pool = RayResourcePool(process_on_nodes=[2])

    actor_wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=actor_cls)
    critic_wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=critic_cls)

    # Reference outputs from the non-colocated groups.
    expected_actor_output = actor_wg.add(data)
    expected_critic_output = critic_wg.sub(data)

    # create colocated workers
    cls_dict = {"actor": actor_cls, "critic": critic_cls}
    ray_cls_with_init = create_colocated_worker_cls_fused(cls_dict)
    wg_dict = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init)
    spawn_wg = wg_dict.spawn(prefix_set=cls_dict.keys())

    colocated_actor_wg = spawn_wg["actor"]
    colocated_critic_wg = spawn_wg["critic"]

    actor_output = colocated_actor_wg.add(data)
    critic_output = colocated_critic_wg.sub(data)

    # Exact match expected (atol=0, rtol=0): same ranks, same arithmetic.
    torch.testing.assert_close(expected_actor_output.batch, actor_output.batch, atol=0, rtol=0)
    torch.testing.assert_close(expected_critic_output.batch, critic_output.batch, atol=0, rtol=0)
    ray.shutdown()
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
In this test, we instantiate a data parallel worker with 8 GPUs
"""
import ray
import tensordict
import torch
from codetiming import Timer
from torch import distributed as dist
from verl import DataProto
from verl.single_controller.base import Worker
from verl.single_controller.base.decorator import Dispatch, register
from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
from verl.utils.ray_utils import parallel_put
@ray.remote
class DummyWorker(Worker):
    """Worker that increments every tensor in the batch by 1 (data-transfer test)."""

    def __init__(self):
        super().__init__()
        dist.init_process_group()

    @register(dispatch_mode=Dispatch.DP_COMPUTE, blocking=False)
    def do_nothing(self, data):
        """Add 1 to every tensor in ``data.batch`` and return the proto.

        On tensordict >= 0.5 the TensorDict is consolidated first so it
        serializes into a single buffer on the way back.
        """
        for key in data.batch.keys():
            data.batch[key] += 1
        # BUGFIX: compare version components numerically. The old lexicographic
        # string comparison ('0.10.0' >= '0.5.0' is False) breaks at 0.10.x.
        if tuple(int(p) for p in tensordict.__version__.split(".")[:2]) >= (0, 5):
            data.batch = data.batch.consolidate()
        return data
def test_data_transfer():
    """Timing + correctness check for shipping large DataProto batches to workers.

    Times pickling, ray.put and dispatch of two (4096 x 32768) int tensors,
    then verifies every worker incremented each tensor entry by 1.
    """
    ray.init()
    # construct resource pool
    resource_pool = RayResourcePool([8])
    cls_with_init = RayClassWithInitArgs(cls=DummyWorker)
    # construct worker group
    wg = RayWorkerGroup(resource_pool, cls_with_init)

    # this is real dataset size
    batch_size = 4096
    seqlen = 32768

    data_dict = {}
    for i in range(2):
        data_dict[str(i)] = torch.randint(0, 10000, (batch_size, seqlen))

    data = DataProto.from_dict(tensors=data_dict)
    print(data)

    # we manually split data here and send to each worker
    data_list = data.chunk(wg.world_size)

    for i in range(wg.world_size):
        # consolidate is necessary
        # BUGFIX: compare version components numerically. The old lexicographic
        # string comparison ('0.10.0' >= '0.5.0' is False) breaks at 0.10.x.
        if tuple(int(p) for p in tensordict.__version__.split(".")[:2]) >= (0, 5):
            data_list[i].batch = data_list[i].batch.consolidate()

    with Timer(name="ray.pickle", initial_text=True):
        for i in range(wg.world_size):
            ray.cloudpickle.pickle.dumps(data_list[i])

    with Timer(name="raw.pickle", initial_text=True):
        import pickle

        for i in range(wg.world_size):
            pickle.dumps(data_list[i])

    # we put in advance
    with Timer(name="put", initial_text=True):
        # takes around 40 seconds
        data_list_ref = parallel_put(data_list)
    # for i in range(wg.world_size):
    #     data_list[i] = ray.put(data_list[i])

    with Timer(name="launch", initial_text=True):
        output_ref = wg.do_nothing(data_list_ref)

    with Timer(name="get", initial_text=True):
        # takes around 40 seconds
        output_lst = ray.get(output_ref)

    # Each worker should have added exactly 1 to every tensor entry.
    for input_data, output_data in zip(data_list, output_lst, strict=True):
        for key in input_data.batch.keys():
            assert torch.all(torch.eq(input_data.batch[key] + 1, output_data.batch[key])), (
                input_data.batch[key],
                output_data.batch[key],
                key,
            )
    ray.shutdown()
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import time
import pytest
import ray
import torch
from tensordict import TensorDict
from verl.protocol import DataProto, DataProtoFuture
from verl.single_controller.base.decorator import Dispatch, register
from verl.single_controller.base.worker import Worker
from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
# Pytest fixture for Ray setup/teardown
@pytest.fixture
def ray_init_shutdown():
    """Start a local CPU-only Ray cluster for a test, then shut it down afterwards.

    Yields control to the test body; pytest runs the code after ``yield`` as
    fixture teardown, so ``ray.shutdown()`` executes even if the test fails.
    """
    ray.init(num_cpus=100)
    yield
    ray.shutdown()
# Define a simple worker for testing
@ray.remote
class DecoratorTestWorker(Worker):
    """Worker exercising the @register decorator with sync and async DP methods."""

    def __init__(self, initial_value=0):
        super().__init__()
        self.value = initial_value
        # Brief pause so worker construction settles before any RPC arrives.
        time.sleep(0.1)

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def dp_compute(self, data: DataProto) -> DataProto:
        """Synchronous DP compute: writes output = input + initial_value + rank."""
        time.sleep(0.1)  # simulate work
        inp = data.batch["input"]
        rank_offset = torch.tensor(self.rank, device=inp.device, dtype=inp.dtype)
        data.batch["output"] = inp + self.value + rank_offset
        return data

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO, blocking=False)
    async def async_dp_compute(self, data: DataProto) -> DataProto:
        """Async DP compute: writes output_async = input * 2 + initial_value + rank."""
        await asyncio.sleep(0.1)  # simulate async work
        inp = data.batch["input"]
        rank_offset = torch.tensor(self.rank, device=inp.device, dtype=inp.dtype)
        data.batch["output_async"] = inp * 2 + self.value + rank_offset
        return data
# Test function for synchronous DP compute
def test_decorator_dp_compute(ray_init_shutdown):
    """
    Tests the default behavior of a synchronous decorated method with DP_COMPUTE_PROTO.
    Verifies the result correctness.
    """
    num_workers = 2
    # CPU-only pool keeps this test lightweight.
    pool = RayResourcePool([num_workers], use_gpu=False, max_colocate_count=1)
    init_args = RayClassWithInitArgs(cls=DecoratorTestWorker, initial_value=10)
    worker_group = RayWorkerGroup(pool, init_args, name_prefix=f"decorator_test_sync_dp_{int(time.time())}")

    # Batch of 4 inputs, sharded 2-and-2 across the two workers.
    data = DataProto(batch=TensorDict({"input": torch.arange(4, dtype=torch.float32)}, batch_size=[4]))

    output = worker_group.dp_compute(data)

    assert isinstance(output, DataProto), "Expected DataProto result"
    assert "output" in output.batch.keys()
    assert len(output) == len(data), "Output length should match input length"

    # Each worker adds initial_value(10) + its own rank to its shard:
    # rank 0 handles rows [0, 1]; rank 1 handles rows [2, 3].
    expected = torch.cat(
        [
            torch.tensor([0, 1], dtype=torch.float32) + 10 + 0,
            torch.tensor([2, 3], dtype=torch.float32) + 10 + 1,
        ]
    )
    torch.testing.assert_close(output.batch["output"], expected, msg="Sync DP compute output data mismatch")
# Test function for async def method with DP compute
def test_decorator_async_function(ray_init_shutdown):
    """
    Tests the decorator with an `async def` method using DP_COMPUTE_PROTO.
    Verifies that the call returns a future and the result is correct after .get().
    """
    num_workers = 2
    pool = RayResourcePool([num_workers], use_gpu=False, max_colocate_count=1)
    init_args = RayClassWithInitArgs(cls=DecoratorTestWorker, initial_value=5)
    worker_group = RayWorkerGroup(pool, init_args, name_prefix=f"decorator_test_async_dp_{int(time.time())}")

    # Batch of 4 inputs, sharded 2-and-2 across the two workers.
    data = DataProto(batch=TensorDict({"input": torch.arange(4, dtype=torch.float32)}, batch_size=[4]))

    # blocking=False: the call must hand back a future, not a DataProto.
    future_output: DataProtoFuture = worker_group.async_dp_compute(data)
    assert isinstance(future_output, DataProtoFuture), "Expected DataProtoFuture for async def call"

    # Resolving the future blocks until all workers finish.
    result_data = future_output.get()

    assert isinstance(result_data, DataProto)
    assert "output_async" in result_data.batch.keys()
    assert len(result_data) == len(data), "Output length should match input length"

    # Each worker computes input * 2 + initial_value(5) + its own rank:
    # rank 0 handles rows [0, 1]; rank 1 handles rows [2, 3].
    expected = torch.cat(
        [
            (torch.tensor([0, 1], dtype=torch.float32) * 2) + 5 + 0,
            (torch.tensor([2, 3], dtype=torch.float32) * 2) + 5 + 1,
        ]
    )
    torch.testing.assert_close(
        result_data.batch["output_async"], expected, msg="Async DP compute output data mismatch"
    )
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import ray
import torch
from tensordict import TensorDict
from verl import DataProto
from verl.single_controller.base.worker import Worker
from verl.single_controller.ray import RayWorkerGroup
from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool
# Show log lines from every Ray worker (Ray de-duplicates repeated logs by default).
os.environ["RAY_DEDUP_LOGS"] = "0"
# Restrict NCCL output to warnings.
os.environ["NCCL_DEBUG"] = "WARN"
@ray.remote
class ModelActor(Worker):
    """Minimal Ray worker used as the remote target for func-generator dispatch."""

    def __init__(self):
        # Initialize the Worker base class instead of skipping it: every other
        # Worker subclass in these tests calls super().__init__(), and omitting
        # it leaves the worker's base-class state (e.g. rank bookkeeping) unset.
        super().__init__()
class HackSelf:
    """Empty stand-in object passed as ``self`` when a method-style function
    is invoked directly on the driver."""
def get_aux_metrics(self, test_proto):
    """Compute a per-sample decode count for a batch of sequence ids.

    Args:
        self: placeholder so the function can be dispatched method-style via
            ``execute_with_func_generator`` (unused in the body).
        test_proto: DataProto whose batch contains a 2D ``sequence_ids`` tensor.

    Returns:
        DataProto carrying the original ``sequence_ids`` and a ``decode_count``
        tensor with the element count of each row.
    """
    sequence_ids = test_proto.batch["sequence_ids"]
    # Take the row length directly from the tensor. The original converted each
    # row with .tolist() just to call len() on it, building a throwaway Python
    # list per sample; numel() yields the same integer without the copy.
    decode_count = [row.numel() for row in sequence_ids]
    ret_proto = DataProto(
        batch=TensorDict(
            {"sequence_ids": sequence_ids, "decode_count": torch.tensor(decode_count)}, batch_size=sequence_ids.size(0)
        )
    )
    return ret_proto
def test():
    """Check that sharded worker-side execution of get_aux_metrics matches a
    direct driver-side call."""
    # construct model
    ray.init()

    # create 2 workers, each hold a GPU
    pool = RayResourcePool([2], use_gpu=True, name_prefix="a")
    shard_wg = RayWorkerGroup(pool, RayClassWithInitArgs(cls=ModelActor))

    test_bs = 8
    test_proto = DataProto(
        TensorDict(
            {
                "sequence_ids": torch.ones([test_bs, 2048], dtype=torch.int64),
            },
            batch_size=test_bs,
        ),
        meta_info={"query_length": 1536},
    )

    # Run sharded across the ranks ...
    ret_proto1 = shard_wg.execute_with_func_generator(get_aux_metrics, test_proto)
    # ... then run the same function locally on the driver and compare.
    ret_proto2 = get_aux_metrics(HackSelf(), test_proto)
    torch.testing.assert_close(ret_proto1.batch["decode_count"], ret_proto2.batch["decode_count"])

    ray.shutdown()
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ray
from verl.single_controller.base import Worker
from verl.single_controller.base.decorator import Dispatch, register
from verl.single_controller.ray.base import (
RayClassWithInitArgs,
RayResourcePool,
RayWorkerGroup,
create_colocated_worker_raw_cls,
)
@ray.remote
class Actor(Worker):
    """Worker whose ``add`` shifts its input by the worker's own rank."""

    def __init__(self) -> None:
        super().__init__()

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def add(self, x):
        """Return ``x`` incremented by this worker's rank (broadcast ONE_TO_ALL)."""
        x += self.rank
        return x
@ray.remote
class Critic(Worker):
    """Worker whose ``sub`` subtracts a fixed offset given at construction."""

    def __init__(self, val) -> None:
        super().__init__()
        self.val = val

    @register(dispatch_mode=Dispatch.ALL_TO_ALL)
    def sub(self, x):
        """Return ``x`` decremented by the stored offset (per-worker ALL_TO_ALL)."""
        x -= self.val
        return x
# Build the fused (colocated) worker base class: a single Ray actor hosting
# both an Actor and a Critic instance, addressable by the "actor"/"critic" keys.
actor_cls = RayClassWithInitArgs(cls=Actor)
critic_cls = RayClassWithInitArgs(cls=Critic, val=10)
cls_dict = {"actor": actor_cls, "critic": critic_cls}
FusedBaseClass = create_colocated_worker_raw_cls(cls_dict)
@ray.remote
class HybridWorker(FusedBaseClass):
    """Fused worker chaining the colocated actor and critic in a single call."""

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def foo(self, x):
        # add the rank on the actor, then subtract the critic's offset
        shifted = self.actor.add(x)
        return self.critic.sub(shifted)
def test_fused_workers():
    """Verify that calls routed through fused sub-workers match the single
    fused entry point."""
    ray.init(num_cpus=100)

    # One node with two CPU worker processes on a shared resource pool.
    resource_pool = RayResourcePool(process_on_nodes=[2], use_gpu=False)

    # Spawn the colocated hybrid workers.
    hybrid_cls_with_init = RayClassWithInitArgs(cls=HybridWorker)
    hybrid_cls_with_init.fused_worker_used = True
    fused_wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=hybrid_cls_with_init)
    fused_wg.fuse(cls_dict.keys())

    # Step through actor.add then critic.sub explicitly ...
    x = fused_wg.actor.add(0.1)
    print(x)
    y = fused_wg.critic.sub(x)
    print(y)
    # ... and via the combined foo() in one hop.
    z = fused_wg.foo(0.1)
    print(z)

    # Both paths must agree per worker.
    for expected, actual in zip(y, z, strict=True):
        assert expected == actual

    ray.shutdown()
# Allow running this test module directly as a script.
if __name__ == "__main__":
    test_fused_workers()
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
import ray
from verl.single_controller.base.worker import Worker
from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup, merge_resource_pool
@ray.remote
class TestActor(Worker):
    """Minimal worker that reports which Ray node hosts it."""

    # TODO: pass *args and **kwargs is bug prone and not very convincing
    def __init__(self, cuda_visible_devices=None) -> None:
        super().__init__(cuda_visible_devices)

    def get_node_id(self):
        """Return the Ray node id of the process running this worker."""
        return ray.get_runtime_context().get_node_id()
def test():
    """Exercise CUDA-device assignment for worker groups on a single node.

    Covers two layouts:
      1. No partition: four worker groups all sharing one 8-GPU pool — each
         group must see GPUs 0..7 in rank order.
      2. Multi-partition: two 4-GPU pools merged into an 8-GPU pool — groups
         on the merged pool see GPUs 0..7, while the rm/ref partitions see
         the first and last four respectively.
    """
    ray.init()

    # --- single-node, no partition -------------------------------------
    print("test single-node-no-partition")
    resource_pool = RayResourcePool([8], use_gpu=True)
    class_with_args = RayClassWithInitArgs(cls=TestActor)
    print("create actor worker group")
    actor_wg = RayWorkerGroup(resource_pool, class_with_args, name_prefix="high_level_api_actor")
    print("create critic worker group")
    # Prefix typo fixed ("hight" -> "high") so all four groups share the same scheme.
    critic_wg = RayWorkerGroup(resource_pool, class_with_args, name_prefix="high_level_api_critic")
    print("create rm worker group")
    rm_wg = RayWorkerGroup(resource_pool, class_with_args, name_prefix="high_level_api_rm")
    print("create ref worker group")
    ref_wg = RayWorkerGroup(resource_pool, class_with_args, name_prefix="high_level_api_ref")

    # Every group spans the full pool, so each rank i is pinned to GPU i.
    assert actor_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(8)]
    assert critic_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(8)]
    assert rm_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(8)]
    assert ref_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(8)]

    del actor_wg
    del critic_wg
    del rm_wg
    del ref_wg
    # Free the GPUs before building the partitioned pools below.
    [ray.util.remove_placement_group(pg) for pg in resource_pool.get_placement_groups()]
    print("wait 5s to remove placement_group")
    time.sleep(5)

    # --- single-node, multi-partition ----------------------------------
    print("test single-node-multi-partition")
    rm_resource_pool = RayResourcePool([4], use_gpu=True, name_prefix="rm")
    ref_resource_pool = RayResourcePool([4], use_gpu=True, name_prefix="ref")
    total_resource_pool = merge_resource_pool(rm_resource_pool, ref_resource_pool)
    assert rm_resource_pool.world_size == 4
    assert ref_resource_pool.world_size == 4
    assert total_resource_pool.world_size == 8

    actor_wg = RayWorkerGroup(total_resource_pool, class_with_args, name_prefix="high_level_api_actor")
    critic_wg = RayWorkerGroup(total_resource_pool, class_with_args, name_prefix="high_level_api_critic")
    rm_wg = RayWorkerGroup(rm_resource_pool, class_with_args, name_prefix="high_level_api_rm")
    ref_wg = RayWorkerGroup(ref_resource_pool, class_with_args, name_prefix="high_level_api_ref")

    # Merged pool sees all 8 GPUs; rm gets the first 4, ref the last 4.
    assert actor_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(8)]
    assert critic_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(8)]
    assert rm_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(4)]
    assert ref_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(4, 8)]
    ray.shutdown()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment