Unverified commit 31b9f19e authored by Zilin Zhu, committed by GitHub

[RL] support weight update with DP attention (#11669)

parent 547003bd
@@ -146,6 +146,13 @@ class _Communicator(Generic[T]):
         if len(self._result_values) == self._fan_out:
             self._result_event.set()
 
+    @staticmethod
+    def merge_results(results):
+        all_success = all([r.success for r in results])
+        all_message = [r.message for r in results]
+        all_message = " | ".join(all_message)
+        return all_success, all_message
+
 
 class TokenizerCommunicatorMixin:
     """Mixin class for TokenizerManager to handle communication with the scheduler."""
@@ -358,10 +365,11 @@ class TokenizerCommunicatorMixin:
     ) -> Tuple[bool, str]:
         self.auto_create_handle_loop()
         assert (
-            self.server_args.dp_size == 1
-        ), "dp_size must be 1 for init parameter update group"
-        result = (await self.init_weights_update_group_communicator(obj))[0]
-        return result.success, result.message
+            self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
+        ), "dp_size must be 1 or dp attention must be enabled for update weights from distributed"
+        results = await self.init_weights_update_group_communicator(obj)
+        return _Communicator.merge_results(results)
 
     async def destroy_weights_update_group(
         self,
@@ -370,10 +378,11 @@ class TokenizerCommunicatorMixin:
     ) -> Tuple[bool, str]:
         self.auto_create_handle_loop()
         assert (
-            self.server_args.dp_size == 1
-        ), "dp_size must be 1 for destroy parameter update group"
-        result = (await self.destroy_weights_update_group_communicator(obj))[0]
-        return result.success, result.message
+            self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
+        ), "dp_size must be 1 or dp attention must be enabled for destroy parameter update group"
+        results = await self.destroy_weights_update_group_communicator(obj)
+        return _Communicator.merge_results(results)
 
     async def update_weights_from_distributed(
         self: TokenizerManager,
@@ -391,8 +400,8 @@ class TokenizerCommunicatorMixin:
         # This means that weight sync
         # cannot run while requests are in progress.
         async with self.model_update_lock.writer_lock:
-            result = (await self.update_weights_from_distributed_communicator(obj))[0]
-            return result.success, result.message
+            results = await self.update_weights_from_distributed_communicator(obj)
+            return _Communicator.merge_results(results)
 
     async def init_weights_send_group_for_remote_instance(
         self,
...
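
The net effect of this change is that weight-update RPCs no longer assume a single result: with DP attention enabled, the communicator fans out to every DP rank, and the per-rank results are folded into one `(success, message)` pair by the new `_Communicator.merge_results`. Below is a minimal, self-contained sketch of that merge semantics; `WeightUpdateResult` is a hypothetical stand-in for the per-rank result objects and is not part of this diff.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class WeightUpdateResult:
    # Hypothetical stand-in for the communicator's per-rank result objects.
    success: bool
    message: str


def merge_results(results: List[WeightUpdateResult]) -> Tuple[bool, str]:
    # Same logic as _Communicator.merge_results above: the overall update
    # succeeds only if every DP rank succeeded, and the per-rank messages
    # are joined so none of them are lost.
    all_success = all([r.success for r in results])
    all_message = " | ".join([r.message for r in results])
    return all_success, all_message


# With dp_size > 1 and DP attention enabled, each DP rank reports its own status:
results = [
    WeightUpdateResult(True, "rank 0: weights updated"),
    WeightUpdateResult(False, "rank 1: update failed"),
]
print(merge_results(results))
# (False, 'rank 0: weights updated | rank 1: update failed')
```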