Unverified Commit 31b9f19e authored by Zilin Zhu, committed by GitHub

[RL] support weight update with DP attention (#11669)

parent 547003bd
@@ -146,6 +146,13 @@ class _Communicator(Generic[T]):
         if len(self._result_values) == self._fan_out:
             self._result_event.set()
+
+    @staticmethod
+    def merge_results(results):
+        all_success = all([r.success for r in results])
+        all_message = [r.message for r in results]
+        all_message = " | ".join(all_message)
+        return all_success, all_message
 
 
 class TokenizerCommunicatorMixin:
     """Mixin class for TokenizerManager to handle communication with the scheduler."""
@@ -358,10 +365,11 @@ class TokenizerCommunicatorMixin:
     ) -> Tuple[bool, str]:
         self.auto_create_handle_loop()
         assert (
-            self.server_args.dp_size == 1
-        ), "dp_size must be 1 for init parameter update group"
-        result = (await self.init_weights_update_group_communicator(obj))[0]
-        return result.success, result.message
+            self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
+        ), "dp_size must be 1 or dp attention must be enabled for update weights from distributed"
+        results = await self.init_weights_update_group_communicator(obj)
+        return _Communicator.merge_results(results)
 
     async def destroy_weights_update_group(
         self,
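
The precondition is relaxed in the same way for each endpoint touched by this change: plain multi-replica DP is still rejected, but DP attention is now allowed because the request fans out to every DP rank and the results are merged. A small sketch of the accepted configurations (the helper function is illustrative, not part of the repository):

```python
def weight_update_allowed(dp_size: int, enable_dp_attention: bool) -> bool:
    # Mirrors the new assertion: either a single DP replica, or DP attention
    # enabled so that every DP rank receives (and answers) the request.
    return dp_size == 1 or enable_dp_attention


assert weight_update_allowed(1, False)       # unchanged: dp_size == 1
assert weight_update_allowed(4, True)        # new: DP attention with dp_size > 1
assert not weight_update_allowed(4, False)   # still rejected: plain DP without DP attention
```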
@@ -370,10 +378,11 @@ class TokenizerCommunicatorMixin:
     ) -> Tuple[bool, str]:
         self.auto_create_handle_loop()
         assert (
-            self.server_args.dp_size == 1
-        ), "dp_size must be 1 for destroy parameter update group"
-        result = (await self.destroy_weights_update_group_communicator(obj))[0]
-        return result.success, result.message
+            self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
+        ), "dp_size must be 1 or dp attention must be enabled for destroy parameter update group"
+        results = await self.destroy_weights_update_group_communicator(obj)
+        return _Communicator.merge_results(results)
 
     async def update_weights_from_distributed(
         self: TokenizerManager,
@@ -391,8 +400,8 @@ class TokenizerCommunicatorMixin:
         # This means that weight sync
         # cannot run while requests are in progress.
         async with self.model_update_lock.writer_lock:
-            result = (await self.update_weights_from_distributed_communicator(obj))[0]
-            return result.success, result.message
+            results = await self.update_weights_from_distributed_communicator(obj)
+            return _Communicator.merge_results(results)
 
     async def init_weights_send_group_for_remote_instance(
         self,
......
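
End to end, the change lets an RL trainer drive the existing weight-update endpoints against a server running with DP attention. A hypothetical client-side sketch follows; the endpoint path matches SGLang's HTTP API, but the launch flags in the comment and the request field names are assumptions for illustration, not verified against this commit:

```python
import requests

# Assumed server address, e.g. a server launched with something like:
#   python -m sglang.launch_server --model-path <model> --dp-size 2 --enable-dp-attention
BASE_URL = "http://localhost:30000"

# Ask the schedulers to join the trainer's weight-update process group.
# With this commit the call is fanned out to every DP rank, and the response
# reflects the merged result: success only if all ranks succeeded, with their
# messages joined by " | ".
resp = requests.post(
    f"{BASE_URL}/init_weights_update_group",
    json={
        "master_address": "127.0.0.1",   # field names are assumptions
        "master_port": 29500,
        "rank_offset": 1,
        "world_size": 3,
        "group_name": "weight_update_group",
        "backend": "nccl",
    },
)
print(resp.json())

# Subsequent calls to /update_weights_from_distributed (and the matching
# destroy endpoint) follow the same fan-out-and-merge pattern.
```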