Unverified Commit 90faf901 authored by BearBiscuit, committed by GitHub
Browse files

[verl] Modify the update_weights func to align with verl's resharding (#5345)


Co-authored-by: Chayenne <zhaochen20@outlook.com>
parent 177320a5
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
import os import os
from typing import Dict, List, Literal, Optional, Tuple, Union from typing import Dict, Iterable, List, Literal, Optional, Tuple, Union
import torch import torch
import torch.distributed as dist import torch.distributed as dist
...@@ -124,7 +124,7 @@ class VerlEngine: ...@@ -124,7 +124,7 @@ class VerlEngine:
def update_weights_from_tensor( def update_weights_from_tensor(
self, self,
named_tensors: List[Tuple[str, torch.Tensor]], named_tensors: Iterable[Tuple[str, torch.Tensor]],
load_format: Optional[str] = None, load_format: Optional[str] = None,
): ):
# Most naive implementation, can optimize a lot if it is bottleneck # Most naive implementation, can optimize a lot if it is bottleneck
...@@ -153,9 +153,12 @@ class VerlEngine: ...@@ -153,9 +153,12 @@ class VerlEngine:
) )
], ],
load_format=load_format, load_format=load_format,
flush_cache=tensor_index == len(named_tensors) - 1, flush_cache=False,
) )
if self._tp_rank == 0:
self._engine.tokenizer_manager.flush_cache()
def release_memory_occupation(self): def release_memory_occupation(self):
if self._tp_rank == 0: if self._tp_rank == 0:
self._engine.release_memory_occupation() self._engine.release_memory_occupation()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment