Commit 004a1ef4 authored by yangshj1's avatar yangshj1
Browse files

cpp fix

parent 5f308e68
......@@ -3324,6 +3324,16 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
if get_pp_group().is_first_rank:
intermediate_tensors = None
else:
def _recv_tensor_dict():
return self._recv_queue.get()
if self._recv_thread is None:
self._recv_stream = torch.cuda.Stream()
self._recv_event = torch.cuda.Event()
self._recv_thread = threading.Thread(target=self._tensor_dict_recv_thread, daemon=True, name="pp_recv_thread")
self._recv_thread.start()
intermediate_tensors = _recv_tensor_dict()
torch.cuda.current_stream().wait_event(self._recv_event)
intermediate_tensors = self.sync_and_slice_intermediate_tensors(
num_input_tokens, intermediate_tensors, True)
......@@ -3392,6 +3402,9 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
hidden_states, aux_hidden_states = model_output
else:
hidden_states = model_output
if isinstance(model_output, IntermediateTensors):
residual_clone = model_output.tensors["residual"].clone()
hidden_states.tensors["residual"] = residual_clone
aux_hidden_states = None
# Broadcast PP output for external_launcher (torchrun)
......
......@@ -319,8 +319,6 @@ class Worker(WorkerBase):
event.synchronize()
def _send_tensor_dict():
# [waterliao mark]
# logger.info(f"[waterliao debug] 222 rank{self.local_rank} _send_tensor_dict intermediate_tensors:{intermediate_tensors}, residual.shape={intermediate_tensors['residual'].shape}")
get_pp_group().send_tensor_dict(
intermediate_tensors.tensors,
all_gather_group=get_tp_group(),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment