Unverified Commit 0626f678 authored by Zilin Zhu's avatar Zilin Zhu Committed by GitHub
Browse files

[RL] support update_weights_from_distributed with different group and multiple weights (#7292)

parent 09e699bb
...@@ -418,12 +418,21 @@ class Engine(EngineBase): ...@@ -418,12 +418,21 @@ class Engine(EngineBase):
self.tokenizer_manager.init_weights_update_group(obj, None) self.tokenizer_manager.init_weights_update_group(obj, None)
) )
def update_weights_from_distributed(self, name: str, dtype, shape): def update_weights_from_distributed(
self,
names: list[str],
dtypes: list[str],
shapes: list[list[int]],
group_name: str = "weight_update_group",
flush_cache: bool = True,
):
"""Update weights from distributed source.""" """Update weights from distributed source."""
obj = UpdateWeightsFromDistributedReqInput( obj = UpdateWeightsFromDistributedReqInput(
name=name, names=names,
dtype=dtype, dtypes=dtypes,
shape=shape, shapes=shapes,
group_name=group_name,
flush_cache=flush_cache,
) )
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
return loop.run_until_complete( return loop.run_until_complete(
......
...@@ -752,9 +752,13 @@ class UpdateWeightFromDiskReqOutput: ...@@ -752,9 +752,13 @@ class UpdateWeightFromDiskReqOutput:
@dataclass
class UpdateWeightsFromDistributedReqInput:
    """Request payload for updating model weights broadcast over a
    torch.distributed process group.

    Each index i across `names`/`dtypes`/`shapes` describes one parameter
    tensor to receive: its model parameter name, its dtype (as a string,
    e.g. "bfloat16"), and its shape.
    """

    # Parameter names to update, in broadcast order.
    names: List[str]
    # Dtype of each parameter, parallel to `names`.
    dtypes: List[str]
    # Shape of each parameter, parallel to `names`.
    shapes: List[List[int]]
    # The group name
    group_name: str = "weight_update_group"
    # Whether to flush the cache after updating weights
    flush_cache: bool = True
@dataclass @dataclass
......
...@@ -2303,6 +2303,7 @@ class Scheduler( ...@@ -2303,6 +2303,7 @@ class Scheduler(
"""Update the online model parameter.""" """Update the online model parameter."""
success, message = self.tp_worker.update_weights_from_distributed(recv_req) success, message = self.tp_worker.update_weights_from_distributed(recv_req)
if success: if success:
if recv_req.flush_cache:
flush_cache_success = self.flush_cache() flush_cache_success = self.flush_cache()
assert flush_cache_success, "Cache flush failed after updating weights" assert flush_cache_success, "Cache flush failed after updating weights"
else: else:
......
...@@ -259,7 +259,7 @@ class TpModelWorker: ...@@ -259,7 +259,7 @@ class TpModelWorker:
self, recv_req: UpdateWeightsFromDistributedReqInput self, recv_req: UpdateWeightsFromDistributedReqInput
): ):
success, message = self.model_runner.update_weights_from_distributed( success, message = self.model_runner.update_weights_from_distributed(
recv_req.name, recv_req.dtype, recv_req.shape recv_req.names, recv_req.dtypes, recv_req.shapes, recv_req.group_name
) )
return success, message return success, message
......
...@@ -225,6 +225,7 @@ class ModelRunner: ...@@ -225,6 +225,7 @@ class ModelRunner:
self.support_pp = ( self.support_pp = (
"pp_proxy_tensors" in inspect.signature(self.model.forward).parameters "pp_proxy_tensors" in inspect.signature(self.model.forward).parameters
) )
self._model_update_group = {}
def initialize(self, min_per_gpu_memory: float): def initialize(self, min_per_gpu_memory: float):
server_args = self.server_args server_args = self.server_args
...@@ -744,7 +745,7 @@ class ModelRunner: ...@@ -744,7 +745,7 @@ class ModelRunner:
) )
try: try:
self._model_update_group = init_custom_process_group( self._model_update_group[group_name] = init_custom_process_group(
backend=backend, backend=backend,
init_method=f"tcp://{master_address}:{master_port}", init_method=f"tcp://{master_address}:{master_port}",
world_size=world_size, world_size=world_size,
...@@ -757,7 +758,7 @@ class ModelRunner: ...@@ -757,7 +758,7 @@ class ModelRunner:
logger.error(message) logger.error(message)
return False, message return False, message
def update_weights_from_distributed(self, name, dtype, shape): def update_weights_from_distributed(self, names, dtypes, shapes, group_name):
""" """
Update specific parameter in the model weights online Update specific parameter in the model weights online
through `_model_update_group` process group. through `_model_update_group` process group.
...@@ -767,19 +768,34 @@ class ModelRunner: ...@@ -767,19 +768,34 @@ class ModelRunner:
dtype: the data type of the parameter to be updated. dtype: the data type of the parameter to be updated.
shape: the shape of the parameter to be updated. shape: the shape of the parameter to be updated.
""" """
assert group_name in self._model_update_group, (
f"Group {group_name} not in {list(self._model_update_group.keys())}. "
"Please call `init_weights_update_group` first."
)
try:
weights = []
handles = []
for name, dtype, shape in zip(names, dtypes, shapes):
target_dtype = ( target_dtype = (
dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype) dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype)
) )
weight = torch.empty(shape, dtype=target_dtype, device=self.device)
handles.append(
torch.distributed.broadcast(
weight,
src=0,
group=self._model_update_group[group_name],
async_op=True,
)
)
weights.append((name, weight))
for handle in handles:
handle.wait()
assert ( self.model.load_weights(weights)
self._model_update_group is not None return True, f"Succeeded to update parameter online."
), "model update group must be initialized"
try:
weights = torch.empty(shape, dtype=target_dtype, device=self.device)
torch.distributed.broadcast(weights, src=0, group=self._model_update_group)
self.model.load_weights([(name, weights)])
return True, f"Succeeded to update parameter {name} online."
except Exception as e: except Exception as e:
error_msg = ( error_msg = (
......
...@@ -294,20 +294,25 @@ def init_process_sgl( ...@@ -294,20 +294,25 @@ def init_process_sgl(
update_parameters.remove("lm_head.weight") update_parameters.remove("lm_head.weight")
# Get weights from the training engine and update the inference engine. # Get weights from the training engine and update the inference engine.
for parameter_name in update_parameters: names = [parameter_name for parameter_name in update_parameters]
dtypes = [torch.bfloat16 if backend == "Engine" else "bfloat16"] * len(names)
shapes = [state_dict_key_to_shape[parameter_name] for parameter_name in names]
if backend == "Engine": if backend == "Engine":
engine.update_weights_from_distributed( engine.update_weights_from_distributed(
parameter_name, names,
dtype=torch.bfloat16, dtypes=dtypes,
shape=state_dict_key_to_shape[parameter_name], shapes=shapes,
group_name="test_parameter_update_group",
) )
else: else:
requests.post( requests.post(
f"{url}/update_weights_from_distributed", f"{url}/update_weights_from_distributed",
json={ json={
"name": parameter_name, "names": names,
"dtype": "bfloat16", "dtypes": dtypes,
"shape": state_dict_key_to_shape[parameter_name], "shapes": shapes,
"group_name": "test_parameter_update_group",
}, },
) )
torch.cuda.synchronize() torch.cuda.synchronize()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment