Commit 845420ac (unverified), authored by 22quinn, committed by GitHub
Browse files

[RLHF] Fix torch.dtype not serializable in example (#22158)


Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
parent e27d25a0
@@ -126,7 +126,10 @@ for name, p in train_model.named_parameters():
 # Synchronize the updated weights to the inference engine.
 for name, p in train_model.named_parameters():
-    handle = llm.collective_rpc.remote("update_weight", args=(name, p.dtype, p.shape))
+    dtype_name = str(p.dtype).split(".")[-1]
+    handle = llm.collective_rpc.remote(
+        "update_weight", args=(name, dtype_name, p.shape)
+    )
     model_update_group.broadcast(p, src=0, stream=torch.cuda.current_stream())
     ray.get(handle)
......
@@ -45,7 +45,8 @@ class WorkerExtension:
             self.device,
         )
-    def update_weight(self, name, dtype, shape):
+    def update_weight(self, name, dtype_name, shape):
+        dtype = getattr(torch, dtype_name)
         weight = torch.empty(shape, dtype=dtype, device="cuda")
         self.model_update_group.broadcast(
             weight, src=0, stream=torch.cuda.current_stream()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment