jerrrrry / verl_grpo · Commits

Commit 7f6cc211, authored Aug 05, 2025 by jerrrrry
Initial commit

Pipeline #2874 failed with stages in 0 seconds
Changes: 421 · Pipelines: 1
Showing 20 changed files with 2298 additions and 0 deletions (+2298, -0)
tests/interactions/test_gsm8k_interaction.py (+421, -0)
tests/interactions/test_interaction_registry.py (+206, -0)
tests/kill_github_tests.sh (+41, -0)
tests/models/test_transformer.py (+166, -0)
tests/models/test_transformers_ulysses.py (+262, -0)
tests/single_controller/__init__.py (+13, -0)
tests/single_controller/base/test_decorator.py (+76, -0)
tests/single_controller/check_worker_alive/main.py (+64, -0)
tests/single_controller/detached_worker/README.md (+14, -0)
tests/single_controller/detached_worker/client.py (+59, -0)
tests/single_controller/detached_worker/run.sh (+6, -0)
tests/single_controller/detached_worker/server.py (+145, -0)
tests/single_controller/test_auto_padding_on_cpu.py (+152, -0)
tests/single_controller/test_colocated_workers.py (+83, -0)
tests/single_controller/test_colocated_workers_fused.py (+83, -0)
tests/single_controller/test_data_transfer.py (+107, -0)
tests/single_controller/test_decorator_on_cpu.py (+141, -0)
tests/single_controller/test_driverfunc_to_worker.py (+84, -0)
tests/single_controller/test_fused_workers_on_cpu.py (+90, -0)
tests/single_controller/test_high_level_scheduling_api.py (+85, -0)
Too many changes to show: to preserve performance only 421 of 421+ files are displayed.
tests/interactions/test_gsm8k_interaction.py (new file, mode 100644)

```python
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023-2024 SGLang Team
# Copyright 2025 ModelBest Inc. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import patch

import pytest

from verl.interactions.gsm8k_interaction import Gsm8kInteraction


class TestGsm8kInteraction:
    """Test cases for Gsm8kInteraction class."""

    def setup_method(self):
        """Set up test environment before each test method."""
        self.config = {"name": "gsm8k"}
        self.interaction = Gsm8kInteraction(self.config)

    def test_init(self):
        """Test Gsm8kInteraction initialization."""
        assert self.interaction._instance_dict == {}
        assert self.interaction.config == self.config
        assert self.interaction.name == "gsm8k"

    @pytest.mark.asyncio
    async def test_start_interaction_with_instance_id(self):
        """Test start_interaction with provided instance_id."""
        instance_id = "test_instance"
        ground_truth = "42"

        result_id = await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth)

        assert result_id == instance_id
        assert instance_id in self.interaction._instance_dict
        assert self.interaction._instance_dict[instance_id]["response"] == ""
        assert self.interaction._instance_dict[instance_id]["ground_truth"] == ground_truth
        assert self.interaction._instance_dict[instance_id]["reward"] == 0.0

    @pytest.mark.asyncio
    async def test_start_interaction_without_instance_id(self):
        """Test start_interaction without provided instance_id (auto-generated)."""
        ground_truth = "42"

        result_id = await self.interaction.start_interaction(ground_truth=ground_truth)

        assert result_id is not None
        assert len(result_id) == 36  # UUID4 length
        assert result_id in self.interaction._instance_dict
        assert self.interaction._instance_dict[result_id]["ground_truth"] == ground_truth

    @pytest.mark.asyncio
    async def test_start_interaction_without_ground_truth(self):
        """Test start_interaction without ground_truth parameter."""
        instance_id = "test_instance"

        result_id = await self.interaction.start_interaction(instance_id=instance_id)

        assert result_id == instance_id
        assert self.interaction._instance_dict[instance_id]["ground_truth"] is None

    @pytest.mark.asyncio
    async def test_generate_response_correct_answer_with_prefix(self):
        """Test generate_response with correct answer already having #### prefix."""
        instance_id = "test_instance"
        ground_truth = "42"

        # Setup instance
        await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth)

        messages = [{"role": "user", "content": "#### 42"}]

        with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=1.0):
            should_terminate, response, reward, metadata = await self.interaction.generate_response(instance_id, messages)

            assert should_terminate is True
            assert response == "Your response is correct!"
            assert reward == 1.0
            assert metadata == {}
            assert self.interaction._instance_dict[instance_id]["response"] == "#### 42"

    @pytest.mark.asyncio
    async def test_generate_response_correct_answer_without_prefix(self):
        """Test generate_response with correct answer missing #### prefix."""
        instance_id = "test_instance"
        ground_truth = "42"

        # Setup instance
        await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth)

        messages = [{"role": "user", "content": "42"}]

        with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=1.0):
            should_terminate, response, reward, metadata = await self.interaction.generate_response(instance_id, messages)

            assert should_terminate is True
            assert response == "Your response is correct!"
            assert reward == 1.0
            assert self.interaction._instance_dict[instance_id]["response"] == "#### 42"

    @pytest.mark.asyncio
    async def test_generate_response_incorrect_answer(self):
        """Test generate_response with incorrect answer."""
        instance_id = "test_instance"
        ground_truth = "42"

        # Setup instance
        await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth)

        messages = [{"role": "user", "content": "24"}]

        with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=0.0):
            should_terminate, response, reward, metadata = await self.interaction.generate_response(instance_id, messages)

            assert should_terminate is False
            assert response == "Your response is incorrect! You need to reflect on your answer and try again."
            assert reward == 0.0
            assert self.interaction._instance_dict[instance_id]["response"] == "#### 24"

    @pytest.mark.asyncio
    async def test_generate_response_multiple_messages(self):
        """Test generate_response with multiple messages (should use last user message)."""
        instance_id = "test_instance"
        ground_truth = "42"

        # Setup instance
        await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth)

        messages = [
            {"role": "user", "content": "What is 2+2?"},
            {"role": "assistant", "content": "Let me think about this..."},
            {"role": "user", "content": "#### 42"},
        ]

        with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=1.0):
            should_terminate, response, reward, metadata = await self.interaction.generate_response(instance_id, messages)

            assert should_terminate is True
            assert response == "Your response is correct!"
            assert self.interaction._instance_dict[instance_id]["response"] == "#### 42"

    @pytest.mark.asyncio
    async def test_generate_response_no_user_message(self):
        """Test generate_response with no user messages."""
        instance_id = "test_instance"
        ground_truth = "42"

        # Setup instance
        await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth)

        messages = [{"role": "assistant", "content": "Hello!"}]

        with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=0.0):
            should_terminate, response, reward, metadata = await self.interaction.generate_response(instance_id, messages)

            assert should_terminate is False
            assert self.interaction._instance_dict[instance_id]["response"] == "#### "

    @pytest.mark.asyncio
    async def test_calculate_score_direct_call(self):
        """Test calculate_score method directly."""
        instance_id = "test_instance"
        ground_truth = "42"

        # Setup instance
        await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth)

        # Set a response
        self.interaction._instance_dict[instance_id]["response"] = "#### 42"

        with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=1.0) as mock_compute:
            score = await self.interaction.calculate_score(instance_id)

            assert score == 1.0
            mock_compute.assert_called_once_with("#### 42", "42", method="flexible", format_score=0.0, score=1.0)

    @pytest.mark.asyncio
    async def test_calculate_score_with_kwargs(self):
        """Test calculate_score method with additional kwargs."""
        instance_id = "test_instance"
        ground_truth = "42"

        # Setup instance
        await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth)

        # Set a response
        self.interaction._instance_dict[instance_id]["response"] = "#### 24"

        with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=0.0) as mock_compute:
            score = await self.interaction.calculate_score(instance_id, extra_param="test")

            assert score == 0.0
            mock_compute.assert_called_once_with("#### 24", "42", method="flexible", format_score=0.0, score=1.0)

    @pytest.mark.asyncio
    async def test_finalize_interaction(self):
        """Test finalize_interaction method."""
        instance_id = "test_instance"
        ground_truth = "42"

        # Setup instance
        await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth)
        assert instance_id in self.interaction._instance_dict

        await self.interaction.finalize_interaction(instance_id)

        assert instance_id not in self.interaction._instance_dict

    @pytest.mark.asyncio
    async def test_finalize_interaction_with_kwargs(self):
        """Test finalize_interaction method with additional kwargs."""
        instance_id = "test_instance"
        ground_truth = "42"

        # Setup instance
        await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth)
        assert instance_id in self.interaction._instance_dict

        await self.interaction.finalize_interaction(instance_id, extra_param="test")

        assert instance_id not in self.interaction._instance_dict

    @pytest.mark.asyncio
    async def test_finalize_nonexistent_interaction(self):
        """Test finalize_interaction with non-existent instance_id."""
        instance_id = "nonexistent_instance"

        # This should raise KeyError
        with pytest.raises(KeyError):
            await self.interaction.finalize_interaction(instance_id)

    @pytest.mark.asyncio
    async def test_full_interaction_workflow_correct(self):
        """Test complete interaction workflow with correct answer."""
        ground_truth = "42"

        # Start interaction
        instance_id = await self.interaction.start_interaction(ground_truth=ground_truth)

        # Generate response with correct answer
        messages = [{"role": "user", "content": "42"}]

        with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=1.0):
            should_terminate, response, reward, metadata = await self.interaction.generate_response(instance_id, messages)

            assert should_terminate is True
            assert reward == 1.0

        # Finalize interaction
        await self.interaction.finalize_interaction(instance_id)
        assert instance_id not in self.interaction._instance_dict

    @pytest.mark.asyncio
    async def test_full_interaction_workflow_incorrect(self):
        """Test complete interaction workflow with incorrect answer."""
        ground_truth = "42"

        # Start interaction
        instance_id = await self.interaction.start_interaction(ground_truth=ground_truth)

        # Generate response with incorrect answer
        messages = [{"role": "user", "content": "24"}]

        with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=0.0):
            should_terminate, response, reward, metadata = await self.interaction.generate_response(instance_id, messages)

            assert should_terminate is False
            assert reward == 0.0

        # Continue with another attempt
        messages.append({"role": "assistant", "content": response})
        messages.append({"role": "user", "content": "42"})

        with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=1.0):
            should_terminate, response, reward, metadata = await self.interaction.generate_response(instance_id, messages)

            assert should_terminate is True
            assert reward == 1.0

        # Finalize interaction
        await self.interaction.finalize_interaction(instance_id)
        assert instance_id not in self.interaction._instance_dict

    @pytest.mark.asyncio
    async def test_multiple_concurrent_interactions(self):
        """Test multiple concurrent interaction instances."""
        ground_truth_1 = "42"
        ground_truth_2 = "24"

        # Start multiple interactions
        instance_id_1 = await self.interaction.start_interaction(ground_truth=ground_truth_1)
        instance_id_2 = await self.interaction.start_interaction(ground_truth=ground_truth_2)

        assert len(self.interaction._instance_dict) == 2
        assert instance_id_1 in self.interaction._instance_dict
        assert instance_id_2 in self.interaction._instance_dict

        # Test responses for both instances
        messages_1 = [{"role": "user", "content": "42"}]
        messages_2 = [{"role": "user", "content": "24"}]

        with patch("verl.utils.reward_score.gsm8k.compute_score", side_effect=[1.0, 1.0]):
            should_terminate_1, _, reward_1, _ = await self.interaction.generate_response(instance_id_1, messages_1)
            should_terminate_2, _, reward_2, _ = await self.interaction.generate_response(instance_id_2, messages_2)

            assert should_terminate_1 is True
            assert should_terminate_2 is True
            assert reward_1 == 1.0
            assert reward_2 == 1.0

        # Finalize both interactions
        await self.interaction.finalize_interaction(instance_id_1)
        await self.interaction.finalize_interaction(instance_id_2)
        assert len(self.interaction._instance_dict) == 0

    @pytest.mark.asyncio
    async def test_edge_case_empty_messages(self):
        """Test edge case with empty messages list."""
        instance_id = "test_instance"
        ground_truth = "42"

        # Setup instance
        await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth)

        messages = []

        with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=0.0):
            should_terminate, response, reward, metadata = await self.interaction.generate_response(instance_id, messages)

            assert should_terminate is False
            assert reward == 0.0
            assert self.interaction._instance_dict[instance_id]["response"] == "#### "

    @pytest.mark.asyncio
    async def test_edge_case_message_without_content(self):
        """Test edge case with message without content field."""
        instance_id = "test_instance"
        ground_truth = "42"

        # Setup instance
        await self.interaction.start_interaction(instance_id=instance_id, ground_truth=ground_truth)

        messages = [
            {"role": "user"}  # Missing content field
        ]

        with patch("verl.utils.reward_score.gsm8k.compute_score", return_value=0.0):
            should_terminate, response, reward, metadata = await self.interaction.generate_response(instance_id, messages)

            assert should_terminate is False
            assert reward == 0.0
            assert self.interaction._instance_dict[instance_id]["response"] == "#### None"

    def test_inheritance_from_base_interaction(self):
        """Test that Gsm8kInteraction properly inherits from BaseInteraction."""
        from verl.interactions.base import BaseInteraction

        assert isinstance(self.interaction, BaseInteraction)

        # Test that all required methods are implemented
        assert hasattr(self.interaction, "start_interaction")
        assert hasattr(self.interaction, "generate_response")
        assert hasattr(self.interaction, "calculate_score")
        assert hasattr(self.interaction, "finalize_interaction")

        # Test that methods are callable
        assert callable(self.interaction.start_interaction)
        assert callable(self.interaction.generate_response)
        assert callable(self.interaction.calculate_score)
        assert callable(self.interaction.finalize_interaction)

    def test_name_attribute_initialization(self):
        """Test name attribute initialization with different configs."""
        # Test with explicit name in config
        config_with_name = {"name": "custom_gsm8k"}
        interaction_with_name = Gsm8kInteraction(config_with_name)
        assert interaction_with_name.name == "custom_gsm8k"

        # Test with default name when not provided in config
        config_without_name = {}
        interaction_without_name = Gsm8kInteraction(config_without_name)
        assert interaction_without_name.name == "interaction_agent"  # Default from BaseInteraction

        # Test that name is accessible as attribute
        assert hasattr(self.interaction, "name")
        assert self.interaction.name == "gsm8k"
```
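The assertions above pin down a small async contract keyed by `instance_id`: start, score, and finalize, with per-instance state kept in `_instance_dict`. The sketch below is illustrative only; `EchoInteraction` and its method bodies are assumptions distilled from the test expectations, not verl's implementation.

```python
# Hypothetical sketch (not part of the commit): the minimal lifecycle these tests
# exercise, written in plain Python so it runs without verl installed.
import uuid


class EchoInteraction:
    """Tracks per-instance state the way the tests expect an interaction to."""

    def __init__(self, config):
        self.config = config
        self.name = config.get("name", "interaction_agent")
        self._instance_dict = {}

    async def start_interaction(self, instance_id=None, ground_truth=None, **kwargs):
        if instance_id is None:
            instance_id = str(uuid.uuid4())  # 36-character id, as asserted above
        self._instance_dict[instance_id] = {"response": "", "ground_truth": ground_truth, "reward": 0.0}
        return instance_id

    async def finalize_interaction(self, instance_id, **kwargs):
        # Raises KeyError for unknown ids, matching test_finalize_nonexistent_interaction.
        del self._instance_dict[instance_id]
```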
tests/interactions/test_interaction_registry.py (new file, mode 100644)

```python
# Copyright 2023-2024 SGLang Team
# Copyright 2025 ModelBest Inc. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import tempfile

import pytest
from omegaconf import OmegaConf

from verl.interactions.base import BaseInteraction
from verl.interactions.gsm8k_interaction import Gsm8kInteraction
from verl.interactions.utils.interaction_registry import (
    get_interaction_class,
    initialize_interactions_from_config,
)


class TestInteractionRegistry:
    def test_get_interaction_class(self):
        """Test getting interaction class by name."""
        # Test getting base interaction class
        base_cls = get_interaction_class("verl.interactions.base.BaseInteraction")
        assert base_cls == BaseInteraction

        # Test getting gsm8k interaction class
        gsm8k_cls = get_interaction_class("verl.interactions.gsm8k_interaction.Gsm8kInteraction")
        assert gsm8k_cls == Gsm8kInteraction

    def test_initialize_single_interaction_from_config(self):
        """Test initializing single interaction from config."""
        # Create temporary config file
        config_content = {
            "interaction": [
                {
                    "name": "test_gsm8k",
                    "class_name": "verl.interactions.gsm8k_interaction.Gsm8kInteraction",
                    "config": {},
                }
            ]
        }

        with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
            OmegaConf.save(config_content, f.name)
            temp_config_path = f.name

        try:
            interaction_map = initialize_interactions_from_config(temp_config_path)

            # Check that interaction was created
            assert len(interaction_map) == 1
            assert "test_gsm8k" in interaction_map
            assert isinstance(interaction_map["test_gsm8k"], Gsm8kInteraction)
            assert interaction_map["test_gsm8k"].name == "test_gsm8k"
        finally:
            os.unlink(temp_config_path)

    def test_initialize_multiple_interactions_from_config(self):
        """Test initializing multiple interactions from config."""
        config_content = {
            "interaction": [
                {
                    "name": "gsm8k_solver",
                    "class_name": "verl.interactions.gsm8k_interaction.Gsm8kInteraction",
                    "config": {},
                },
                {
                    "name": "base_agent",
                    "class_name": "verl.interactions.base.BaseInteraction",
                    "config": {"custom_param": "test_value"},
                },
            ]
        }

        with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
            OmegaConf.save(config_content, f.name)
            temp_config_path = f.name

        try:
            interaction_map = initialize_interactions_from_config(temp_config_path)

            # Check that both interactions were created
            assert len(interaction_map) == 2
            assert "gsm8k_solver" in interaction_map
            assert "base_agent" in interaction_map

            # Check types
            assert isinstance(interaction_map["gsm8k_solver"], Gsm8kInteraction)
            assert isinstance(interaction_map["base_agent"], BaseInteraction)

            # Check names were injected
            assert interaction_map["gsm8k_solver"].name == "gsm8k_solver"
            assert interaction_map["base_agent"].name == "base_agent"

            # Check custom config was passed
            assert interaction_map["base_agent"].config.get("custom_param") == "test_value"
        finally:
            os.unlink(temp_config_path)

    def test_initialize_interaction_without_explicit_name(self):
        """Test that interaction name is derived from class name when not specified."""
        config_content = {
            "interaction": [{"class_name": "verl.interactions.gsm8k_interaction.Gsm8kInteraction", "config": {}}]
        }

        with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
            OmegaConf.save(config_content, f.name)
            temp_config_path = f.name

        try:
            interaction_map = initialize_interactions_from_config(temp_config_path)

            # Check that interaction name was derived from class name
            assert len(interaction_map) == 1
            assert "gsm8k" in interaction_map  # Should be "gsm8k" after removing "interaction" suffix
            assert isinstance(interaction_map["gsm8k"], Gsm8kInteraction)
            assert interaction_map["gsm8k"].name == "gsm8k"
        finally:
            os.unlink(temp_config_path)

    def test_initialize_empty_config(self):
        """Test initializing from empty config."""
        config_content = {"interaction": []}

        with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
            OmegaConf.save(config_content, f.name)
            temp_config_path = f.name

        try:
            interaction_map = initialize_interactions_from_config(temp_config_path)
            assert len(interaction_map) == 0
        finally:
            os.unlink(temp_config_path)

    def test_invalid_class_name(self):
        """Test handling of invalid class name."""
        config_content = {
            "interaction": [{"name": "invalid", "class_name": "invalid.module.InvalidClass", "config": {}}]
        }

        with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
            OmegaConf.save(config_content, f.name)
            temp_config_path = f.name

        try:
            with pytest.raises(ModuleNotFoundError):
                initialize_interactions_from_config(temp_config_path)
        finally:
            os.unlink(temp_config_path)

    def test_duplicate_interaction_names(self):
        """Test handling of duplicate interaction names."""
        config_content = {
            "interaction": [
                {"name": "duplicate", "class_name": "verl.interactions.base.BaseInteraction", "config": {}},
                {
                    "name": "duplicate",
                    "class_name": "verl.interactions.gsm8k_interaction.Gsm8kInteraction",
                    "config": {},
                },
            ]
        }

        with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
            OmegaConf.save(config_content, f.name)
            temp_config_path = f.name

        try:
            with pytest.raises(ValueError, match="Duplicate interaction name 'duplicate' found"):
                initialize_interactions_from_config(temp_config_path)
        finally:
            os.unlink(temp_config_path)

    def test_auto_name_generation_edge_cases(self):
        """Test automatic name generation for various class name patterns."""
        config_content = {
            "interaction": [
                {"class_name": "verl.interactions.base.BaseInteraction", "config": {}},
                {"class_name": "verl.interactions.gsm8k_interaction.Gsm8kInteraction", "config": {}},
            ]
        }

        with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
            OmegaConf.save(config_content, f.name)
            temp_config_path = f.name

        try:
            interaction_map = initialize_interactions_from_config(temp_config_path)

            # Check that names were generated correctly
            assert len(interaction_map) == 2
            assert "base" in interaction_map  # BaseInteraction -> base
            assert "gsm8k" in interaction_map  # Gsm8kInteraction -> gsm8k
        finally:
            os.unlink(temp_config_path)
```
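Every registry test above follows the same pattern: write an `interaction` list to a YAML file, then load it back with `initialize_interactions_from_config`. A condensed usage sketch of that pattern follows; the file path "interaction.yaml" is illustrative.

```python
# Condensed usage sketch distilled from the tests above; the config path is illustrative.
from omegaconf import OmegaConf

from verl.interactions.utils.interaction_registry import initialize_interactions_from_config

config = {
    "interaction": [
        {"name": "gsm8k", "class_name": "verl.interactions.gsm8k_interaction.Gsm8kInteraction", "config": {}},
    ]
}
OmegaConf.save(config, "interaction.yaml")

interaction_map = initialize_interactions_from_config("interaction.yaml")
print(interaction_map["gsm8k"].name)  # -> "gsm8k"
```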
tests/kill_github_tests.sh (new file, mode 100644)

```bash
#!/bin/bash

if [ "$#" -ne 1 ]; then
    echo "Usage: $0 YOUR_GITHUB_TOKEN"
    echo "Please provide exactly one input argument for your github token."
    exit 1
fi

# Set your GitHub repository details
OWNER="volcengine"
REPO="verl"
TOKEN=$1

# API URL for workflow runs
API_URL="https://api.github.com/repos/$OWNER/$REPO/actions/runs?status=queued"

# Check required commands
command -v jq > /dev/null 2>&1 || { echo "jq is required but not installed. Aborting."; exit 1; }

# Get queued workflow runs
response=$(curl -s -H "Authorization: token $TOKEN" -H "Accept: application/vnd.github.v3+json" "$API_URL")

# Run this for debugging
# echo $response

# Extract run IDs
queued_run_ids=$(echo "$response" | jq -r '.workflow_runs[] | .id')

if [ -z "$queued_run_ids" ]; then
    echo "No queued workflow runs found."
    exit 0
fi

# Cancel each queued run
for run_id in $queued_run_ids; do
    echo "Cancelling run $run_id"
    cancel_url="https://api.github.com/repos/$OWNER/$REPO/actions/runs/$run_id/cancel"
    curl -s -X POST -H "Authorization: token $TOKEN" -H "Accept: application/vnd.github.v3+json" "$cancel_url"
done

echo "Cancelled all queued workflow runs."
```
tests/models/test_transformer.py (new file, mode 100644)

```python
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input
from transformers import (
    AutoModelForCausalLM,
    AutoModelForTokenClassification,
    GemmaConfig,
    LlamaConfig,
    MistralConfig,
    Qwen2Config,
)

from verl.utils.model import compute_position_id_with_mask, create_random_mask
from verl.utils.torch_functional import log_probs_from_logits_all_rmpad, masked_mean

# TODO(sgm): add more models for test
# we only need one scale for each model
test_configs = [
    LlamaConfig(num_hidden_layers=1),
    MistralConfig(num_hidden_layers=1),
    GemmaConfig(num_hidden_layers=1),
    Qwen2Config(num_hidden_layers=1),
]


def test_hf_casual_models():
    batch_size = 4
    seqlen = 128
    response_length = 127

    for config in test_configs:
        # config = AutoConfig.from_pretrained(test_case)
        with torch.device("cuda"):
            model = AutoModelForCausalLM.from_config(
                config=config, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
            )
            model = model.to(device="cuda")

        input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seqlen), device="cuda")
        attention_mask = create_random_mask(
            input_ids=input_ids,
            max_ratio_of_left_padding=0.1,
            max_ratio_of_valid_token=0.8,
            min_ratio_of_valid_token=0.5,
        )
        position_ids = compute_position_id_with_mask(attention_mask)  # TODO(sgm): we can construct the position_ids_rmpad here

        input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), attention_mask)  # input_ids_rmpad (total_nnz, ...)
        input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz)

        # unpad the position_ids to align the rotary
        position_ids_rmpad = index_first_axis(
            rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices
        ).transpose(0, 1)

        # input with input_ids_rmpad and postition_ids to enable flash attention varlen
        logits_rmpad = model(input_ids_rmpad, position_ids=position_ids_rmpad, use_cache=False).logits  # (1, total_nnz, vocab_size)

        origin_logits = model(
            input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, use_cache=False
        ).logits
        origin_logits_rmpad, origin_logits_indices, *_ = unpad_input(origin_logits, attention_mask)

        logits_rmpad = logits_rmpad.squeeze(0)
        log_probs = log_probs_from_logits_all_rmpad(
            input_ids_rmpad=input_ids_rmpad,
            logits_rmpad=logits_rmpad,
            indices=indices,
            batch_size=batch_size,
            seqlen=seqlen,
            response_length=response_length,
        )  # (batch, seqlen)
        origin_log_probs = log_probs_from_logits_all_rmpad(
            input_ids_rmpad=input_ids_rmpad,
            logits_rmpad=origin_logits_rmpad,
            indices=origin_logits_indices,
            batch_size=batch_size,
            seqlen=seqlen,
            response_length=response_length,
        )  # (batch, seqlen)

        torch.testing.assert_close(
            masked_mean(log_probs, attention_mask[:, -response_length - 1 : -1]),
            masked_mean(origin_log_probs, attention_mask[:, -response_length - 1 : -1]),
            atol=1e-2,
            rtol=1e-5,
        )
    print("Check pass")


def test_hf_value_models():
    batch_size = 4
    seqlen = 128

    for config in test_configs:
        # config = AutoConfig.from_pretrained(test_case)
        config.num_labels = 1
        config.classifier_dropout = 0
        config.hidden_dropout = 0
        with torch.device("cuda"):
            model = AutoModelForTokenClassification.from_config(
                config=config, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
            )
            model = model.to(device="cuda")

        input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seqlen), device="cuda")
        attention_mask = create_random_mask(
            input_ids=input_ids,
            max_ratio_of_left_padding=0.1,
            max_ratio_of_valid_token=0.8,
            min_ratio_of_valid_token=0.5,
        )
        position_ids = compute_position_id_with_mask(attention_mask)  # TODO(sgm): we can construct the position_ids_rmpad here

        input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), attention_mask)  # input_ids_rmpad (total_nnz, ...)
        input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz)

        # unpad the position_ids to align the rotary
        position_ids_rmpad = index_first_axis(
            rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices
        ).transpose(0, 1)

        origin_logits = model(
            input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, use_cache=False
        ).logits

        # input with input_ids_rmpad and postition_ids to enable flash attention varlen
        rmpad_logits = model(input_ids_rmpad, position_ids=position_ids_rmpad, use_cache=False).logits  # (1, total_nnz, 1)
        rmpad_logits = rmpad_logits.squeeze(0)
        pad_logits = pad_input(rmpad_logits, indices, batch_size, seqlen=seqlen)

        torch.testing.assert_close(
            masked_mean(pad_logits, attention_mask[:, :, None]),
            masked_mean(origin_logits, attention_mask[:, :, None]),
            atol=1e-2,
            rtol=1e-5,
        )
    print("Value model check pass")


if __name__ == "__main__":
    test_hf_casual_models()
    test_hf_value_models()
```
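Both tests above hinge on the remove-padding (rmpad) round trip: `unpad_input` packs only the valid tokens into a `(total_nnz, ...)` tensor, and `pad_input` scatters results back to `(batch, seqlen, ...)`. The sketch below illustrates that invariant with the same two calls; the toy sizes and mask are illustrative, and it only exercises the padding utilities, not the model forward.

```python
# Hypothetical shape-checking sketch of the unpad/pad round trip the tests rely on.
import torch
from flash_attn.bert_padding import pad_input, unpad_input

batch_size, seqlen, hidden = 2, 8, 4
x = torch.randn(batch_size, seqlen, hidden)
mask = torch.zeros(batch_size, seqlen, dtype=torch.int64)
mask[0, :5] = 1  # 5 valid tokens in the first sequence
mask[1, :3] = 1  # 3 valid tokens in the second

x_rmpad, indices, *_ = unpad_input(x, mask)                     # (total_nnz, hidden), total_nnz = 8
x_roundtrip = pad_input(x_rmpad, indices, batch_size, seqlen=seqlen)  # back to (batch, seqlen, hidden)

# Valid positions are recovered exactly; padded positions come back as zeros.
assert torch.equal(x_roundtrip[mask.bool()], x[mask.bool()])
```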
tests/models/test_transformers_ulysses.py (new file, mode 100644)

```python
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import copy
from dataclasses import dataclass

import pytest
import torch
import torch.distributed
from flash_attn.bert_padding import index_first_axis, rearrange, unpad_input
from torch.distributed import init_device_mesh
from transformers import AutoModelForCausalLM, LlamaConfig, PretrainedConfig, Qwen2Config

from verl.models.transformers.monkey_patch import apply_monkey_patch
from verl.protocol import DataProto
from verl.utils.distributed import initialize_global_process_group
from verl.utils.model import compute_position_id_with_mask, create_random_mask
from verl.utils.ulysses import (
    gather_outputs_and_unpad,
    get_ulysses_sequence_parallel_world_size,
    set_ulysses_sequence_parallel_group,
    ulysses_pad_and_slice_inputs,
)
from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager

# TODO(sgm): add more models for test
# we only need one scale for each model


@dataclass
class SequenceParallelConfig:
    config: PretrainedConfig
    sp_size: int
    is_valid: bool


def test_configs():
    return [
        SequenceParallelConfig(
            LlamaConfig(num_hidden_layers=2, num_attention_heads=32, num_key_value_heads=32), sp_size=8, is_valid=True
        ),
        SequenceParallelConfig(
            Qwen2Config(num_hidden_layers=2, num_attention_heads=28, num_key_value_heads=4, hidden_size=3584),
            sp_size=4,
            is_valid=True,
        ),
        SequenceParallelConfig(
            Qwen2Config(num_hidden_layers=2, num_attention_heads=28, num_key_value_heads=4, hidden_size=3584),
            sp_size=8,
            is_valid=False,
        ),
        SequenceParallelConfig(
            Qwen2Config(num_hidden_layers=2, num_attention_heads=32, num_key_value_heads=4), sp_size=4, is_valid=True
        ),
        SequenceParallelConfig(
            Qwen2Config(num_hidden_layers=2, num_attention_heads=32, num_key_value_heads=4), sp_size=8, is_valid=True
        ),
    ]


def sync_model_parameters_global(layer):
    # synchronize weights
    for p in layer.parameters():
        torch.distributed.broadcast(tensor=p.data, src=0)


@pytest.mark.parametrize("test_config", test_configs())
def test_hf_casual_fwd_bwd(test_config):
    if not torch.distributed.is_initialized():
        initialize_global_process_group()

    context = contextlib.nullcontext() if test_config.is_valid else pytest.raises(AssertionError)
    with context:
        world_size = torch.distributed.get_world_size()
        _hf_casual_fwd_bwd(test_config.config, test_config.sp_size, world_size // test_config.sp_size)

    # TODO: seems not work, will cause `socketStartConnect: Connect to xxx failed : Software caused connection abort`
    # torch.distributed.destroy_process_group()


def _hf_casual_fwd(config, sp_size, dp_size):
    assert torch.cuda.device_count() >= 2, "need at least 2 gpus for test"
    ulysses_device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(dp_size, sp_size), mesh_dim_names=("dp", "sp"))
    sharding_manager = FSDPUlyssesShardingManager(ulysses_device_mesh)

    batch_size = 1
    seqlen = 128
    # response_length = 127

    # patch before load
    with torch.device("cuda"):
        model = AutoModelForCausalLM.from_config(
            config=config, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
        )
        apply_monkey_patch(model, sp_size)
        model = model.to(device="cuda")
    sync_model_parameters_global(model)

    # different rank will generate different input_ids following fsdp
    input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seqlen), device="cuda")
    attention_mask = create_random_mask(
        input_ids=input_ids, max_ratio_of_left_padding=0, max_ratio_of_valid_token=0.9, min_ratio_of_valid_token=0.8
    )
    position_ids = compute_position_id_with_mask(attention_mask)  # TODO(sgm): we can construct the position_ids_rmpad here

    model_inputs = {
        "input_ids": input_ids.cuda(),
        "attention_mask": attention_mask.cuda(),
        "position_ids": position_ids.int().cuda(),
    }

    model_inputs = DataProto.from_dict(model_inputs)

    # 1. perform ulysses forward
    with sharding_manager:
        model_inputs = sharding_manager.preprocess_data(model_inputs)
        input_ids = model_inputs.batch["input_ids"]
        attention_mask = model_inputs.batch["attention_mask"]
        position_ids = model_inputs.batch["position_ids"]
        input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), attention_mask)  # input_ids_rmpad (total_nnz, ...)
        input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz)

        # unpad the position_ids to align the rotary
        position_ids_rmpad = index_first_axis(
            rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices
        ).transpose(0, 1)

        # slice input tensor for ulysses
        # input_ids are padded and sliced
        # postition_ids are only padded but not sliced
        input_ids_rmpad_sliced, position_ids_rmpad_padded, pad_size = ulysses_pad_and_slice_inputs(
            input_ids_rmpad, position_ids_rmpad, sp_size=get_ulysses_sequence_parallel_world_size()
        )

        # input with input_ids_rmpad and postition_ids to enable flash attention varlen
        logits_split_in_seq = model(
            input_ids_rmpad_sliced, position_ids=position_ids_rmpad_padded, use_cache=False
        ).logits  # (1, total_nnz/n, vocab_size)

        # all_gather output
        logits_full = gather_outputs_and_unpad(logits_split_in_seq, gather_dim=1, unpad_dim=1, padding_size=pad_size)

        # 2. perform normal forward
        set_ulysses_sequence_parallel_group(None)
        logits_rmpad_local = model(input_ids_rmpad, position_ids=position_ids_rmpad, use_cache=False).logits  # (1, total_nnz, vocab_size)

        mean_local = logits_rmpad_local.mean()
        mean_full = logits_full.mean()
        torch.testing.assert_close(mean_local, mean_full, rtol=1e-2, atol=1e-5)


def _hf_casual_fwd_bwd(config, sp_size, dp_size):
    assert torch.cuda.device_count() >= 2, "need at least 2 gpus for test"
    ulysses_device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(dp_size, sp_size), mesh_dim_names=("dp", "sp"))
    sharding_manager = FSDPUlyssesShardingManager(ulysses_device_mesh)

    batch_size = 1
    seqlen = 128
    # response_length = 127

    # patch before load
    with torch.device("cuda"):
        model = AutoModelForCausalLM.from_config(
            config=config, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
        )
        apply_monkey_patch(model, sp_size)
        model = model.to(device="cuda")
    sync_model_parameters_global(model)

    # different rank will generate different input_ids following fsdp
    input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seqlen), device="cuda")
    attention_mask = create_random_mask(
        input_ids=input_ids, max_ratio_of_left_padding=0, max_ratio_of_valid_token=0.9, min_ratio_of_valid_token=0.8
    )
    position_ids = compute_position_id_with_mask(attention_mask)  # TODO(sgm): we can construct the position_ids_rmpad here

    model_inputs = {
        "input_ids": input_ids.cuda(),
        "attention_mask": attention_mask.cuda(),
        "position_ids": position_ids.int().cuda(),
    }

    model_inputs = DataProto.from_dict(model_inputs)

    # 1. perform ulysses forward
    with sharding_manager:
        model_inputs = sharding_manager.preprocess_data(model_inputs)
        input_ids = model_inputs.batch["input_ids"]
        attention_mask = model_inputs.batch["attention_mask"]
        position_ids = model_inputs.batch["position_ids"]
        input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), attention_mask)  # input_ids_rmpad (total_nnz, ...)
        input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz)

        # unpad the position_ids to align the rotary
        position_ids_rmpad = index_first_axis(
            rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices
        ).transpose(0, 1)

        # slice input tensor for ulysses
        # input_ids are padded and sliced
        # postition_ids are only padded but not sliced
        input_ids_rmpad_sliced, position_ids_rmpad_padded, pad_size = ulysses_pad_and_slice_inputs(
            input_ids_rmpad, position_ids_rmpad, sp_size=get_ulysses_sequence_parallel_world_size()
        )

        # input with input_ids_rmpad and postition_ids to enable flash attention varlen
        logits_split_in_seq = model(
            input_ids_rmpad_sliced, position_ids=position_ids_rmpad_padded, use_cache=False
        ).logits  # (1, total_nnz/n, vocab_size)

        # all_gather output
        logits_full = gather_outputs_and_unpad(logits_split_in_seq, gather_dim=1, unpad_dim=1, padding_size=pad_size)

        # 2. perform normal forward
        set_ulysses_sequence_parallel_group(None)
        input_ids_full = copy.deepcopy(input_ids_rmpad)
        position_ids_full = copy.deepcopy(position_ids_rmpad)
        model_no_sp = copy.deepcopy(model)
        logits_rmpad_local = model_no_sp(input_ids_full, position_ids=position_ids_full, use_cache=False).logits  # (1, total_nnz, vocab_size)

        mean_local = logits_rmpad_local.mean()
        mean_full = logits_full.mean()
        mean_full.backward()
        mean_local.backward()

        # 3. check the gradients
        grad = model.model.layers[0].self_attn.q_proj.weight.grad
        grad_full = model_no_sp.model.layers[0].self_attn.q_proj.weight.grad
        torch.testing.assert_close(mean_local, mean_full, rtol=1e-2, atol=1e-5)
        torch.testing.assert_close(grad, grad_full, atol=1e-2, rtol=1e-5)


if __name__ == "__main__":
    pytest.main([__file__, "-svv"])
```
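The parametrized configs above encode one sizing rule: the distributed world size is split into a two-dimensional device mesh of `dp_size * sp_size` ranks, with `dp_size = world_size // sp_size`. The helper below is an illustrative restatement of that relationship (it is not part of the test file); it only reproduces the arithmetic fed into `init_device_mesh`.

```python
# Illustrative helper: how a fixed world size maps onto the (dp, sp) mesh used above.
def mesh_shape(world_size: int, sp_size: int) -> tuple:
    assert world_size % sp_size == 0, "sp_size must divide the number of ranks"
    dp_size = world_size // sp_size
    return (dp_size, sp_size)

# e.g. 8 GPUs with sp_size=4 gives a (2, 4) mesh: 2 data-parallel groups of
# 4 sequence-parallel ranks, matching init_device_mesh(..., mesh_shape=(dp, sp)).
print(mesh_shape(8, 4))  # (2, 4)
```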
tests/single_controller/__init__.py (new file, mode 100644)

```python
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
```
tests/single_controller/base/test_decorator.py (new file, mode 100644)

```python
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest

import verl.single_controller.base.decorator as decorator_module
from verl.single_controller.base.decorator import (
    DISPATCH_MODE_FN_REGISTRY,
    Dispatch,
    _check_dispatch_mode,
    get_predefined_dispatch_fn,
    register_dispatch_mode,
    update_dispatch_mode,
)


@pytest.fixture
def reset_dispatch_registry():
    # Store original state
    original_registry = DISPATCH_MODE_FN_REGISTRY.copy()
    yield
    # Reset registry after test
    decorator_module.DISPATCH_MODE_FN_REGISTRY.clear()
    decorator_module.DISPATCH_MODE_FN_REGISTRY.update(original_registry)


def test_register_new_dispatch_mode(reset_dispatch_registry):
    # Test registration
    def dummy_dispatch(worker_group, *args, **kwargs):
        return args, kwargs

    def dummy_collect(worker_group, output):
        return output

    register_dispatch_mode("TEST_MODE", dummy_dispatch, dummy_collect)

    # Verify enum extension
    _check_dispatch_mode(Dispatch.TEST_MODE)

    # Verify registry update
    assert get_predefined_dispatch_fn(Dispatch.TEST_MODE) == {
        "dispatch_fn": dummy_dispatch,
        "collect_fn": dummy_collect,
    }

    # Clean up
    Dispatch.remove("TEST_MODE")


def test_update_existing_dispatch_mode(reset_dispatch_registry):
    # Store original implementation
    original_mode = Dispatch.ONE_TO_ALL

    # New implementations
    def new_dispatch(worker_group, *args, **kwargs):
        return args, kwargs

    def new_collect(worker_group, output):
        return output

    # Test update
    update_dispatch_mode(original_mode, new_dispatch, new_collect)

    # Verify update
    assert get_predefined_dispatch_fn(original_mode)["dispatch_fn"] == new_dispatch
    assert get_predefined_dispatch_fn(original_mode)["collect_fn"] == new_collect
```
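Distilled from the first test above, registering a custom dispatch mode takes a name plus a dispatch function and a collect function with the signatures shown there. The sketch below is a usage illustration only; the mode name "BROADCAST_ARGS" and the function bodies are hypothetical.

```python
# Usage sketch based on the test above; the mode name and bodies are illustrative.
from verl.single_controller.base.decorator import Dispatch, register_dispatch_mode


def broadcast_dispatch(worker_group, *args, **kwargs):
    # Hand the same args/kwargs to every worker in the group.
    return args, kwargs


def gather_collect(worker_group, output):
    # Return the per-worker outputs unchanged.
    return output


register_dispatch_mode("BROADCAST_ARGS", broadcast_dispatch, gather_collect)
# The new mode is then addressable as Dispatch.BROADCAST_ARGS and can be removed
# again with Dispatch.remove("BROADCAST_ARGS"), as the test's cleanup step does.
```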
tests/single_controller/check_worker_alive/main.py (new file, mode 100644)

```python
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import time

import ray

from verl.single_controller.base.decorator import Dispatch, register
from verl.single_controller.base.worker import Worker
from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup


@ray.remote
class TestActor(Worker):
    def __init__(self) -> None:
        super().__init__()

    @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
    def foo(self, wait_time):
        time.sleep(wait_time)
        sys.exit(1)


if __name__ == "__main__":
    wait_time = int(os.getenv("WAIT_TIME", "10"))

    ray.init()

    # test single-node-no-partition
    print("test single-node-no-partition")
    resource_pool = RayResourcePool([2], use_gpu=False)
    class_with_args = RayClassWithInitArgs(cls=TestActor)

    print("create worker group")
    wg = RayWorkerGroup(resource_pool, class_with_args, name_prefix="test")
    wg.start_worker_aliveness_check(1)
    time.sleep(1)

    print(time.time(), "start foo")
    _ = wg.foo(wait_time)
    print("foo started")
    print(
        time.time(),
        f"wait 6x wait time {wait_time * 6} to let signal returned to process but still not exceed process wait time",
    )
    time.sleep(wait_time * 6)

    ray.shutdown()
```
tests/single_controller/detached_worker/README.md (new file, mode 100644)

# Detached Worker

## How to run (Only on a single node)

- Start a local ray cluster:

```bash
ray start --head --port=6379
```

- Run the server:

```bash
python3 server.py
```

- On another terminal, run the client:

```bash
python3 client.py
```
tests/single_controller/detached_worker/client.py (new file, mode 100644)

```python
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
In client, we can get the server handler and send RPC request
"""

import ray
import torch
from server import Trainer
from tensordict import TensorDict

from verl import DataProto
from verl.single_controller.ray import RayClassWithInitArgs
from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup


def compute_position_id_with_mask(mask):
    return torch.clip(torch.cumsum(mask, dim=-1) - 1, min=0, max=None)


if __name__ == "__main__":
    ray.init(address="auto", namespace="verl")
    # get the worker group using names
    worker_names = ["trainerTrainer_0:0", "trainerTrainer_0:1"]
    cls_with_init_args = RayClassWithInitArgs(cls=Trainer)
    worker_group = NVMegatronRayWorkerGroup.from_detached(worker_names=worker_names, ray_cls_with_init=cls_with_init_args)

    batch_size = 16
    sequence_length = 1024

    # give Trainer some data to train
    input_ids = torch.randint(low=0, high=256, size=(batch_size, sequence_length), dtype=torch.int64, device="cuda")
    attention_mask = torch.ones_like(input_ids)
    position_ids = compute_position_id_with_mask(attention_mask)

    data = DataProto(
        batch=TensorDict(
            {"input_ids": input_ids, "attention_mask": attention_mask, "position_ids": position_ids},
            batch_size=batch_size,
        ),
        meta_info={},
    )

    output = worker_group.train_model(data)

    print(output)
```
tests/single_controller/detached_worker/run.sh (new file, mode 100644, no newline at end of file)

```bash
#!/bin/bash
ray start --head --port=6379
python3 server.py
python3 client.py
ray stop --force
```
tests/single_controller/detached_worker/server.py (new file, mode 100644)

```python
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Server starts a Trainer. Client sends data to the server to train.
"""

import os

os.environ["MEGATRON_USE_CUDA_TIMER"] = "0"
os.environ["MEGATRON_START_PROCESS_TIMER"] = "False"
os.environ["NCCL_DEBUG"] = "WARN"

import ray
import torch
from megatron.core import parallel_state as mpu
from megatron.core import tensor_parallel
from megatron.core.models.gpt.gpt_model import ModelType
from omegaconf import OmegaConf
from tensordict import TensorDict
from torch import nn
from transformers import LlamaConfig

from verl import DataProto
from verl.models.llama.megatron import ParallelLlamaForCausalLMRmPadPP
from verl.single_controller.base.decorator import Dispatch, register
from verl.single_controller.base.megatron.worker import MegatronWorker
from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool
from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup
from verl.utils.megatron.optimizer import get_megatron_optimizer
from verl.utils.megatron_utils import get_model, init_megatron_optim_config, mcore_model_parallel_config


@ray.remote
class Trainer(MegatronWorker):
    def __init__(self):
        super().__init__()

        if not torch.distributed.is_initialized():
            rank = int(os.environ["LOCAL_RANK"])
            torch.distributed.init_process_group(backend="nccl")
            torch.cuda.set_device(rank)

            mpu.initialize_model_parallel(
                tensor_model_parallel_size=2,
                pipeline_model_parallel_size=1,
                virtual_pipeline_model_parallel_size=None,
                pipeline_model_parallel_split_rank=None,
                use_sharp=False,
                context_parallel_size=1,
                expert_model_parallel_size=1,
                nccl_communicator_config_path=None,
            )
            tensor_parallel.model_parallel_cuda_manual_seed(10)

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def init_model(self):
        actor_model_config = LlamaConfig(
            vocab_size=256,
            hidden_size=2048,
            intermediate_size=5504,
            num_hidden_layers=24,
            num_attention_heads=16,
            num_key_value_heads=16,
        )

        megatron_config = mcore_model_parallel_config(sequence_parallel=True, params_dtype=torch.bfloat16)
        self.megatron_config = megatron_config

        def megatron_actor_model_provider(pre_process, post_process):
            # vpp is not supported yet because it will hang for some reason. Need debugging
            # this_megatron_config = copy.deepcopy(megatron_config)
            # this_megatron_config.virtual_pipeline_model_parallel_rank = vpp_rank
            parallel_model = ParallelLlamaForCausalLMRmPadPP(
                config=actor_model_config,
                megatron_config=megatron_config,
                pre_process=pre_process,
                post_process=post_process,
            )
            parallel_model.cuda()
            return parallel_model

        actor_module = get_model(
            model_provider_func=megatron_actor_model_provider,
            model_type=ModelType.encoder_or_decoder,
            wrap_with_ddp=True,
        )
        actor_module = nn.ModuleList(actor_module)

        optim_config = OmegaConf.create({"lr": 1e-6, "clip_grad": 1.0})
        optim_config = init_megatron_optim_config(optim_config)
        self.optimizer_config = optim_config
        actor_optimizer = get_megatron_optimizer(model=actor_module, config=optim_config)

        self.model = actor_module[0]
        self.optimizer = actor_optimizer

    @register(dispatch_mode=Dispatch.MEGATRON_COMPUTE_PROTO)
    def train_model(self, data: DataProto) -> DataProto:
        input_ids = data.batch["input_ids"]
        attention_mask = data.batch["attention_mask"]
        position_ids = data.batch["position_ids"]

        self.optimizer.zero_grad()
        self.model.zero_grad_buffer(zero_buffer=(not self.optimizer_config.use_distributed_optimizer))
        # use use_contiguous_buffers_in_local_ddp and no overlap_dp_param_comm
        # update for 1 iteration
        output = self.model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids).logits
        output.mean().backward()
        update_successful, grad_norm, num_zeros_in_grad = self.optimizer.step(self.megatron_config, self.megatron_config.timers)

        return DataProto(batch=TensorDict({"loss": output.detach()}, batch_size=output.shape[0]))


if __name__ == "__main__":
    ray.init(address="auto", namespace="verl")

    resource_pool = RayResourcePool(process_on_nodes=[2], detached=True)
    cls_with_init_args = RayClassWithInitArgs(cls=Trainer)
    worker_group = NVMegatronRayWorkerGroup(
        resource_pool=resource_pool,
        ray_cls_with_init=cls_with_init_args,
        name_prefix="trainer",
        detached=True,
    )

    worker_group.init_model()

    worker_names = worker_group.worker_names
    print(worker_names)
```
tests/single_controller/test_auto_padding_on_cpu.py
0 → 100644
View file @ 7f6cc211

# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import ray
import torch

from verl import DataProto
from verl.protocol import DataProtoConfig
from verl.single_controller.base import Worker
from verl.single_controller.base.decorator import Dispatch, register
from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup

# or set env var VERL_AUTO_PADDING = "1" / "true"
DataProtoConfig.auto_padding = True


@ray.remote
class Actor(Worker):
    def __init__(self) -> None:
        super().__init__()

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def add(self, data: DataProto):
        data.batch["a"] += self.rank
        return data


def test_auto_padding():
    ray.init(num_cpus=100)

    chunk_size = 4
    actor_cls = RayClassWithInitArgs(cls=Actor)
    resource_pool = RayResourcePool(process_on_nodes=[chunk_size], use_gpu=False)
    actor_wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=actor_cls)

    # test locally first
    for test_size in range(4, 20):
        local_data = DataProto.from_dict({"a": torch.zeros(test_size)}, {"na": np.zeros(test_size, dtype=object)})
        # print(f"before padding, local_data = {local_data}")
        padding_size = (chunk_size - (test_size % chunk_size)) if (test_size % chunk_size > 0) else 0
        local_data.padding(padding_size)
        # print(f"after padding, local_data = {local_data}")
        assert len(local_data) == len(local_data) + len(local_data) % chunk_size, (
            f"expecting padded length to be {len(local_data) + len(local_data) % chunk_size}, but got {len(local_data)}"
        )
        chunked = local_data.chunk(chunk_size)
        assert len(chunked) == chunk_size, f"during test_size = {test_size}, expecting {chunk_size}, got {chunked}"
        for dp in chunked:
            assert len(dp) == test_size // chunk_size + bool(test_size % chunk_size), (
                f"test size = {test_size}, expecting dp to be length of "
                f"{test_size // chunk_size + bool(test_size % chunk_size)}, but got {len(dp)}: {dp} {chunked}"
            )

    # test with RayWorkerGroup method decorated as dispatch_mode=Dispatch.DP_COMPUTE_PROTO
    data = DataProto.from_dict({"a": torch.zeros(10)}, {"na": np.array([str(i) for i in range(10)], dtype=object)})
    output = actor_wg.add(data)
    print(output.batch["a"])
    assert len(output) == 10, "Failed in args split and padding."

    data = DataProto.from_dict({"a": torch.zeros(10)}, {"na": np.array([str(i) for i in range(10)], dtype=object)})
    output = actor_wg.add(data=data)
    print(output.batch["a"])
    assert len(output) == 10, "Failed in kwargs split and padding."

    data = DataProto.from_dict({"a": torch.zeros(1)}, {"na": np.array([str(i) for i in range(1)], dtype=object)})
    output = actor_wg.add(data)
    print(output.batch["a"])
    assert len(output) == 1, "Failed in args split and padding."

    data = DataProto.from_dict({"a": torch.zeros(1)}, {"na": np.array([str(i) for i in range(1)], dtype=object)})
    output = actor_wg.add(data=data)
    print(output.batch["a"])
    assert len(output) == 1, "Failed in kwargs split and padding."

    data = DataProto.from_dict({"a": torch.zeros(8)}, {"na": np.array([str(i) for i in range(8)], dtype=object)})
    output = actor_wg.add(data)
    print(output.batch["a"])
    assert len(output) == 8, "Failed in args split and padding."

    data = DataProto.from_dict({"a": torch.zeros(8)}, {"na": np.array([str(i) for i in range(8)], dtype=object)})
    output = actor_wg.add(data=data)
    print(output.batch["a"])
    assert len(output) == 8, "Failed in kwargs split and padding."

    # test data proto specific config
    DataProtoConfig.auto_padding = False

    data = DataProto.from_dict(
        {"a": torch.zeros(10)}, {"na": np.array([str(i) for i in range(10)], dtype=object)}, auto_padding=True
    )
    output = actor_wg.add(data)
    print(output.batch["a"])
    assert len(output) == 10, "Failed in args split and padding."

    data = DataProto.from_dict(
        {"a": torch.zeros(10)}, {"na": np.array([str(i) for i in range(10)], dtype=object)}, auto_padding=True
    )
    output = actor_wg.add(data=data)
    print(output.batch["a"])
    assert len(output) == 10, "Failed in kwargs split and padding."

    data = DataProto.from_single_dict(
        {"a": torch.zeros(1), "na": np.array([str(i) for i in range(1)], dtype=object)}, auto_padding=True
    )
    output = actor_wg.add(data)
    print(output.batch["a"])
    assert len(output) == 1, "Failed in args split and padding."

    data = DataProto.from_single_dict(
        {"a": torch.zeros(1), "na": np.array([str(i) for i in range(1)], dtype=object)}, auto_padding=True
    )
    output = actor_wg.add(data=data)
    print(output.batch["a"])
    assert len(output) == 1, "Failed in kwargs split and padding."

    data = DataProto.from_single_dict({"a": torch.zeros(8), "na": np.array([str(i) for i in range(8)], dtype=object)})
    output = actor_wg.add(data)
    print(output.batch["a"])
    assert len(output) == 8, "Failed in args split and padding."

    data = DataProto.from_single_dict({"a": torch.zeros(8), "na": np.array([str(i) for i in range(8)], dtype=object)})
    output = actor_wg.add(data=data)
    print(output.batch["a"])
    assert len(output) == 8, "Failed in kwargs split and padding."

    ray.shutdown()


if __name__ == "__main__":
    test_auto_padding()
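
Note: the padding rule exercised above is plain modular arithmetic; a batch is padded up to the next multiple of the worker-group size before dispatch. The sketch below is not part of the test and the helper name pad_to_multiple is made up for illustration only.

# Minimal sketch of the padding arithmetic used in test_auto_padding.
# `pad_to_multiple` is a hypothetical helper, not a verl API.
def pad_to_multiple(batch_len: int, world_size: int) -> int:
    """Return how many rows must be appended so batch_len divides evenly."""
    remainder = batch_len % world_size
    return 0 if remainder == 0 else world_size - remainder

# With 4 workers, sizes 4..7 need 0, 3, 2 and 1 extra rows respectively.
assert [pad_to_multiple(n, 4) for n in range(4, 8)] == [0, 3, 2, 1]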
tests/single_controller/test_colocated_workers.py
0 → 100644
View file @ 7f6cc211

# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ray

from verl import DataProto
from verl.single_controller.base import Worker
from verl.single_controller.base.decorator import Dispatch, register
from verl.single_controller.ray.base import (
    RayClassWithInitArgs,
    RayResourcePool,
    RayWorkerGroup,
    create_colocated_worker_cls,
)


@ray.remote
class Actor(Worker):
    def __init__(self) -> None:
        super().__init__()

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def add(self, data: DataProto):
        data.batch["a"] += self.rank
        return data


@ray.remote
class Critic(Worker):
    def __init__(self, config) -> None:
        super().__init__()
        self.config = config

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    async def sub(self, data: DataProto):
        data.batch["a"] -= self.config["b"]
        return data


def test_colocated_workers():
    ray.init()
    import torch

    data = DataProto.from_dict({"a": torch.zeros(10)})
    # create separate workers on the same resource pool
    actor_cls = RayClassWithInitArgs(cls=Actor)
    critic_cls = RayClassWithInitArgs(cls=Critic, config={"b": 10})

    resource_pool = RayResourcePool(process_on_nodes=[2])
    actor_wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=actor_cls)
    critic_wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=critic_cls)

    expected_actor_output = actor_wg.add(data)
    expected_critic_output = critic_wg.sub(data)

    # create colocated workers
    cls_dict = {"actor": actor_cls, "critic": critic_cls}
    ray_cls_with_init = create_colocated_worker_cls(cls_dict)
    wg_dict = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init)
    spawn_wg = wg_dict.spawn(prefix_set=cls_dict.keys())

    colocated_actor_wg = spawn_wg["actor"]
    colocated_critic_wg = spawn_wg["critic"]

    actor_output = colocated_actor_wg.add(data)
    critic_output = colocated_critic_wg.sub(data)

    torch.testing.assert_close(expected_actor_output.batch, actor_output.batch, atol=0, rtol=0)
    torch.testing.assert_close(expected_critic_output.batch, critic_output.batch, atol=0, rtol=0)

    ray.shutdown()
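
For reference, with two workers and DP_COMPUTE_PROTO dispatch the 10-element batch above is split into two chunks of 5: the actor adds its rank and the critic subtracts config["b"]. The stand-alone sketch below only spells out the values the assert_close calls compare, assuming that even split.

import torch

# Expected values for the colocated-worker test above (sketch only, assuming an
# even split of the 10-element batch across 2 workers).
a = torch.zeros(10)
actor_expected = torch.cat([a[:5] + 0, a[5:] + 1])  # each worker adds its rank
critic_expected = a - 10                             # every worker subtracts config["b"] == 10
print(actor_expected, critic_expected)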
tests/single_controller/test_colocated_workers_fused.py
0 → 100644
View file @ 7f6cc211

# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ray

from verl import DataProto
from verl.single_controller.base import Worker
from verl.single_controller.base.decorator import Dispatch, register
from verl.single_controller.ray.base import (
    RayClassWithInitArgs,
    RayResourcePool,
    RayWorkerGroup,
    create_colocated_worker_cls_fused,
)


@ray.remote
class Actor(Worker):
    def __init__(self) -> None:
        super().__init__()

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def add(self, data: DataProto):
        data.batch["a"] += self.rank
        return data


@ray.remote
class Critic(Worker):
    def __init__(self, config) -> None:
        super().__init__()
        self.config = config

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def sub(self, data: DataProto):
        data.batch["a"] -= self.config["b"]
        return data


def test_colocated_workers_fused():
    ray.init()
    import torch

    data = DataProto.from_dict({"a": torch.zeros(10)})
    # create separate workers on the same resource pool
    actor_cls = RayClassWithInitArgs(cls=Actor)
    critic_cls = RayClassWithInitArgs(cls=Critic, config={"b": 10})

    resource_pool = RayResourcePool(process_on_nodes=[2])
    actor_wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=actor_cls)
    critic_wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=critic_cls)

    expected_actor_output = actor_wg.add(data)
    expected_critic_output = critic_wg.sub(data)

    # create colocated workers
    cls_dict = {"actor": actor_cls, "critic": critic_cls}
    ray_cls_with_init = create_colocated_worker_cls_fused(cls_dict)
    wg_dict = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init)
    spawn_wg = wg_dict.spawn(prefix_set=cls_dict.keys())

    colocated_actor_wg = spawn_wg["actor"]
    colocated_critic_wg = spawn_wg["critic"]

    actor_output = colocated_actor_wg.add(data)
    critic_output = colocated_critic_wg.sub(data)

    torch.testing.assert_close(expected_actor_output.batch, actor_output.batch, atol=0, rtol=0)
    torch.testing.assert_close(expected_critic_output.batch, critic_output.batch, atol=0, rtol=0)

    ray.shutdown()
tests/single_controller/test_data_transfer.py
0 → 100644
View file @ 7f6cc211

# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
In this test, we instantiate a data parallel worker with 8 GPUs
"""

import ray
import tensordict
import torch
from codetiming import Timer
from torch import distributed as dist

from verl import DataProto
from verl.single_controller.base import Worker
from verl.single_controller.base.decorator import Dispatch, register
from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
from verl.utils.ray_utils import parallel_put


@ray.remote
class DummyWorker(Worker):
    def __init__(self):
        super().__init__()
        dist.init_process_group()

    @register(dispatch_mode=Dispatch.DP_COMPUTE, blocking=False)
    def do_nothing(self, data):
        for key in data.batch.keys():
            data.batch[key] += 1
        if tensordict.__version__ >= "0.5.0":
            data.batch = data.batch.consolidate()
        return data


def test_data_transfer():
    ray.init()
    # construct resource pool
    resource_pool = RayResourcePool([8])
    cls_with_init = RayClassWithInitArgs(cls=DummyWorker)
    # construct worker group
    wg = RayWorkerGroup(resource_pool, cls_with_init)

    # this is real dataset size
    batch_size = 4096
    seqlen = 32768

    data_dict = {}
    for i in range(2):
        data_dict[str(i)] = torch.randint(0, 10000, (batch_size, seqlen))

    data = DataProto.from_dict(tensors=data_dict)
    print(data)

    # we manually split data here and send to each worker
    data_list = data.chunk(wg.world_size)

    for i in range(wg.world_size):
        # consolidate is necessary
        if tensordict.__version__ >= "0.5.0":
            data_list[i].batch = data_list[i].batch.consolidate()

    with Timer(name="ray.pickle", initial_text=True):
        for i in range(wg.world_size):
            ray.cloudpickle.pickle.dumps(data_list[i])

    with Timer(name="raw.pickle", initial_text=True):
        import pickle

        for i in range(wg.world_size):
            pickle.dumps(data_list[i])

    # we put in advance
    with Timer(name="put", initial_text=True):
        # takes around 40 seconds
        data_list_ref = parallel_put(data_list)
        # for i in range(wg.world_size):
        #     data_list[i] = ray.put(data_list[i])

    with Timer(name="launch", initial_text=True):
        output_ref = wg.do_nothing(data_list_ref)

    with Timer(name="get", initial_text=True):
        # takes around 40 seconds
        output_lst = ray.get(output_ref)

    for input_data, output_data in zip(data_list, output_lst, strict=True):
        for key in input_data.batch.keys():
            assert torch.all(torch.eq(input_data.batch[key] + 1, output_data.batch[key])), (
                input_data.batch[key],
                output_data.batch[key],
                key,
            )

    ray.shutdown()
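
The Timer blocks above only report wall-clock time for serialization and transfer. If codetiming is not available, the same kind of measurement can be approximated with the standard library; the sketch below is not part of the test and uses an arbitrary, much smaller tensor so it runs quickly anywhere.

import pickle
import time

import torch

# Rough stand-in for the codetiming blocks above: time how long it takes to
# pickle one chunk of random token ids (the shape here is illustrative only).
chunk = {"input_ids": torch.randint(0, 10000, (512, 4096))}
start = time.perf_counter()
payload = pickle.dumps(chunk)
print(f"pickled {len(payload) / 1e6:.1f} MB in {time.perf_counter() - start:.3f}s")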
tests/single_controller/test_decorator_on_cpu.py
0 → 100644
View file @ 7f6cc211

# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import time

import pytest
import ray
import torch
from tensordict import TensorDict

from verl.protocol import DataProto, DataProtoFuture
from verl.single_controller.base.decorator import Dispatch, register
from verl.single_controller.base.worker import Worker
from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup


# Pytest fixture for Ray setup/teardown
@pytest.fixture
def ray_init_shutdown():
    ray.init(num_cpus=100)
    yield
    ray.shutdown()


# Define a simple worker for testing
@ray.remote
class DecoratorTestWorker(Worker):
    def __init__(self, initial_value=0):
        super().__init__()
        self.value = initial_value
        # Simulate some setup if needed
        time.sleep(0.1)  # Ensure worker init completes

    # Test method for synchronous DP compute (default behavior)
    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def dp_compute(self, data: DataProto) -> DataProto:
        time.sleep(0.1)  # Simulate work
        rank_value = torch.tensor(self.rank, device=data.batch["input"].device, dtype=data.batch["input"].dtype)
        data.batch["output"] = data.batch["input"] + self.value + rank_value
        return data

    # Test async def method with DP compute (default behavior)
    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO, blocking=False)
    async def async_dp_compute(self, data: DataProto) -> DataProto:
        # Simulate async work
        await asyncio.sleep(0.1)  # Simulate async work
        rank_value = torch.tensor(self.rank, device=data.batch["input"].device, dtype=data.batch["input"].dtype)
        data.batch["output_async"] = data.batch["input"] * 2 + self.value + rank_value
        return data


# Test function for synchronous DP compute
def test_decorator_dp_compute(ray_init_shutdown):
    """
    Tests the default behavior of a synchronous decorated method with DP_COMPUTE_PROTO.
    Verifies the result correctness.
    """
    num_workers = 2
    resource_pool = RayResourcePool([num_workers], use_gpu=False, max_colocate_count=1)  # Use CPU for simplicity
    cls_with_args = RayClassWithInitArgs(cls=DecoratorTestWorker, initial_value=10)
    worker_group = RayWorkerGroup(
        resource_pool, cls_with_args, name_prefix=f"decorator_test_sync_dp_{int(time.time())}"
    )

    # Prepare input data (size 4, for 2 workers)
    input_tensor = torch.arange(4, dtype=torch.float32)
    data = DataProto(batch=TensorDict({"input": input_tensor}, batch_size=[4]))

    # Call the decorated method
    output = worker_group.dp_compute(data)

    # Assert the result correctness
    assert isinstance(output, DataProto), "Expected DataProto result"
    assert "output" in output.batch.keys()
    assert len(output) == len(data), "Output length should match input length"

    # Expected output calculation for DP_COMPUTE_PROTO with 2 workers
    # Worker 0 gets data[0:2], Worker 1 gets data[2:4]
    # Worker 0 adds initial_value(10) + rank(0) = 10
    # Worker 1 adds initial_value(10) + rank(1) = 11
    expected_output_part1 = torch.tensor([0, 1], dtype=torch.float32) + 10 + 0
    expected_output_part2 = torch.tensor([2, 3], dtype=torch.float32) + 10 + 1
    expected_output = torch.cat([expected_output_part1, expected_output_part2])
    torch.testing.assert_close(output.batch["output"], expected_output, msg="Sync DP compute output data mismatch")


# Test function for async def method with DP compute
def test_decorator_async_function(ray_init_shutdown):
    """
    Tests the decorator with an `async def` method using DP_COMPUTE_PROTO.
    Verifies that the call returns a future and the result is correct after .get().
    """
    num_workers = 2
    resource_pool = RayResourcePool([num_workers], use_gpu=False, max_colocate_count=1)
    cls_with_args = RayClassWithInitArgs(cls=DecoratorTestWorker, initial_value=5)
    worker_group = RayWorkerGroup(
        resource_pool, cls_with_args, name_prefix=f"decorator_test_async_dp_{int(time.time())}"
    )

    # Prepare input data (size 4, for 2 workers)
    input_tensor = torch.arange(4, dtype=torch.float32)
    data = DataProto(batch=TensorDict({"input": input_tensor}, batch_size=[4]))

    # Call the async decorated method - this should return a future
    future_output: DataProtoFuture = worker_group.async_dp_compute(data)

    # Assert that the call returned a future
    assert isinstance(future_output, DataProtoFuture), "Expected DataProtoFuture for async def call"

    # Get the result (this should block)
    result_data = future_output.get()

    # Assert the result correctness
    assert isinstance(result_data, DataProto)
    assert "output_async" in result_data.batch.keys()
    assert len(result_data) == len(data), "Output length should match input length"

    # Expected output calculation for DP_COMPUTE_PROTO with 2 workers
    # Worker 0 gets data[0:2], Worker 1 gets data[2:4]
    # Worker 0 calculates: input * 2 + initial_value(5) + rank(0)
    # Worker 1 calculates: input * 2 + initial_value(5) + rank(1)
    expected_output_part1 = (torch.tensor([0, 1], dtype=torch.float32) * 2) + 5 + 0
    expected_output_part2 = (torch.tensor([2, 3], dtype=torch.float32) * 2) + 5 + 1
    expected_output = torch.cat([expected_output_part1, expected_output_part2])
    torch.testing.assert_close(
        result_data.batch["output_async"], expected_output, msg="Async DP compute output data mismatch"
    )
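
Both tests depend only on the ray_init_shutdown fixture, so they can be run on a CPU-only machine. One way to invoke just this file programmatically is sketched below; any ordinary pytest command line works equally well, and the flags chosen here are only a preference.

# Sketch: run only this file's tests via pytest's Python API.
import pytest

raise SystemExit(pytest.main(["-q", "tests/single_controller/test_decorator_on_cpu.py"]))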
tests/single_controller/test_driverfunc_to_worker.py
0 → 100644
View file @ 7f6cc211

# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

import ray
import torch
from tensordict import TensorDict

from verl import DataProto
from verl.single_controller.base.worker import Worker
from verl.single_controller.ray import RayWorkerGroup
from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool

os.environ["RAY_DEDUP_LOGS"] = "0"
os.environ["NCCL_DEBUG"] = "WARN"


@ray.remote
class ModelActor(Worker):
    def __init__(self):
        pass


class HackSelf:
    def __init__(self):
        pass


def get_aux_metrics(self, test_proto):
    sequence_ids = test_proto.batch["sequence_ids"]
    decode_count = []
    for i in range(sequence_ids.size(0)):
        decode_count.append(len(sequence_ids[i].tolist()))
    ret_proto = DataProto(
        batch=TensorDict(
            {"sequence_ids": sequence_ids, "decode_count": torch.tensor(decode_count)},
            batch_size=sequence_ids.size(0),
        )
    )
    return ret_proto


def test():
    # construct model
    ray.init()

    # create 2 workers, each hold a GPU
    resource_pool = RayResourcePool([2], use_gpu=True, name_prefix="a")
    class_with_args = RayClassWithInitArgs(cls=ModelActor)
    shard_wg = RayWorkerGroup(resource_pool, class_with_args)

    test_bs = 8
    test_proto = DataProto(
        TensorDict(
            {
                "sequence_ids": torch.ones([test_bs, 2048], dtype=torch.int64),
            },
            batch_size=test_bs,
        ),
        meta_info={"query_length": 1536},
    )

    # Sharding among different ranks
    ret_proto1 = shard_wg.execute_with_func_generator(get_aux_metrics, test_proto)

    # compare execute on driver
    hs = HackSelf()
    ret_proto2 = get_aux_metrics(hs, test_proto)
    torch.testing.assert_close(ret_proto1.batch["decode_count"], ret_proto2.batch["decode_count"])

    ray.shutdown()
tests/single_controller/test_fused_workers_on_cpu.py
0 → 100644
View file @ 7f6cc211

# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ray

from verl.single_controller.base import Worker
from verl.single_controller.base.decorator import Dispatch, register
from verl.single_controller.ray.base import (
    RayClassWithInitArgs,
    RayResourcePool,
    RayWorkerGroup,
    create_colocated_worker_raw_cls,
)


@ray.remote
class Actor(Worker):
    def __init__(self) -> None:
        super().__init__()

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def add(self, x):
        x += self.rank
        return x


@ray.remote
class Critic(Worker):
    def __init__(self, val) -> None:
        super().__init__()
        self.val = val

    @register(dispatch_mode=Dispatch.ALL_TO_ALL)
    def sub(self, x):
        x -= self.val
        return x


actor_cls = RayClassWithInitArgs(cls=Actor)
critic_cls = RayClassWithInitArgs(cls=Critic, val=10)
cls_dict = {"actor": actor_cls, "critic": critic_cls}
FusedBaseClass = create_colocated_worker_raw_cls(cls_dict)


@ray.remote
class HybridWorker(FusedBaseClass):
    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def foo(self, x):
        return self.critic.sub(self.actor.add(x))


def test_fused_workers():
    ray.init(num_cpus=100)

    # create separate workers on the same resource pool
    process_on_nodes = [2]
    resource_pool = RayResourcePool(process_on_nodes=process_on_nodes, use_gpu=False)

    # create colocated workers
    hybrid_cls_with_init = RayClassWithInitArgs(cls=HybridWorker)
    hybrid_cls_with_init.fused_worker_used = True
    fused_wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=hybrid_cls_with_init)
    fused_wg.fuse(cls_dict.keys())

    x = fused_wg.actor.add(0.1)
    print(x)
    y = fused_wg.critic.sub(x)
    print(y)
    z = fused_wg.foo(0.1)
    print(z)

    for i, j in zip(y, z, strict=True):
        assert i == j

    ray.shutdown()


if __name__ == "__main__":
    test_fused_workers()
tests/single_controller/test_high_level_scheduling_api.py
0 → 100644
View file @ 7f6cc211

# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time

import ray

from verl.single_controller.base.worker import Worker
from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup, merge_resource_pool


@ray.remote
class TestActor(Worker):
    # TODO: pass *args and **kwargs is bug prone and not very convincing
    def __init__(self, cuda_visible_devices=None) -> None:
        super().__init__(cuda_visible_devices)

    def get_node_id(self):
        return ray.get_runtime_context().get_node_id()


def test():
    ray.init()

    # test single-node-no-partition
    print("test single-node-no-partition")
    resource_pool = RayResourcePool([8], use_gpu=True)
    class_with_args = RayClassWithInitArgs(cls=TestActor)

    print("create actor worker group")
    actor_wg = RayWorkerGroup(resource_pool, class_with_args, name_prefix="high_level_api_actor")
    print("create critic worker group")
    critic_wg = RayWorkerGroup(resource_pool, class_with_args, name_prefix="high_level_api_critic")
    print("create rm worker group")
    rm_wg = RayWorkerGroup(resource_pool, class_with_args, name_prefix="high_level_api_rm")
    print("create ref worker group")
    ref_wg = RayWorkerGroup(resource_pool, class_with_args, name_prefix="high_level_api_ref")

    assert actor_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(8)]
    assert critic_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(8)]
    assert rm_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(8)]
    assert ref_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(8)]

    del actor_wg
    del critic_wg
    del rm_wg
    del ref_wg

    [ray.util.remove_placement_group(pg) for pg in resource_pool.get_placement_groups()]
    print("wait 5s to remove placement_group")
    time.sleep(5)

    # test single-node-multi-partition
    print("test single-node-multi-partition")
    rm_resource_pool = RayResourcePool([4], use_gpu=True, name_prefix="rm")
    ref_resource_pool = RayResourcePool([4], use_gpu=True, name_prefix="ref")
    total_resource_pool = merge_resource_pool(rm_resource_pool, ref_resource_pool)

    assert rm_resource_pool.world_size == 4
    assert ref_resource_pool.world_size == 4
    assert total_resource_pool.world_size == 8

    actor_wg = RayWorkerGroup(total_resource_pool, class_with_args, name_prefix="high_level_api_actor")
    critic_wg = RayWorkerGroup(total_resource_pool, class_with_args, name_prefix="high_level_api_critic")
    rm_wg = RayWorkerGroup(rm_resource_pool, class_with_args, name_prefix="high_level_api_rm")
    ref_wg = RayWorkerGroup(ref_resource_pool, class_with_args, name_prefix="high_level_api_ref")

    assert actor_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(8)]
    assert critic_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(8)]
    assert rm_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(4)]
    assert ref_wg.execute_all_sync("get_cuda_visible_devices") == [str(i) for i in range(4, 8)]

    ray.shutdown()
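
The final four asserts encode how the merged pool hands out GPU indices: worker groups built on the merged pool see all eight devices, while groups built on the original rm/ref pools see the first and last four respectively. Spelled out as plain lists (a sketch of the expected values only; the real test obtains them via execute_all_sync):

# Expected CUDA_VISIBLE_DEVICES per worker group in the multi-partition case.
merged_devices = [str(i) for i in range(8)]   # actor_wg / critic_wg
rm_devices = [str(i) for i in range(4)]       # rm_wg -> ['0', '1', '2', '3']
ref_devices = [str(i) for i in range(4, 8)]   # ref_wg -> ['4', '5', '6', '7']
print(merged_devices, rm_devices, ref_devices)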