Unverified commit 90faf901, authored by BearBiscuit and committed by GitHub
Browse files

[verl] Modify the update_weights func to align with verl's resharding (#5345)


Co-authored-by: Chayenne <zhaochen20@outlook.com>
parent 177320a5
......@@ -12,7 +12,7 @@
# limitations under the License.
# ==============================================================================
import os
from typing import Dict, List, Literal, Optional, Tuple, Union
from typing import Dict, Iterable, List, Literal, Optional, Tuple, Union
import torch
import torch.distributed as dist
......@@ -124,7 +124,7 @@ class VerlEngine:
def update_weights_from_tensor(
self,
named_tensors: List[Tuple[str, torch.Tensor]],
named_tensors: Iterable[Tuple[str, torch.Tensor]],
load_format: Optional[str] = None,
):
# Most naive implementation, can optimize a lot if it is bottleneck
......@@ -153,9 +153,12 @@ class VerlEngine:
)
],
load_format=load_format,
flush_cache=tensor_index == len(named_tensors) - 1,
flush_cache=False,
)
if self._tp_rank == 0:
self._engine.tokenizer_manager.flush_cache()
def release_memory_occupation(self):
if self._tp_rank == 0:
self._engine.release_memory_occupation()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment