Unverified Commit cb432f17 authored by Cheng Wan's avatar Cheng Wan Committed by GitHub
Browse files

saving hidden_states.clone() (#7705)

parent 1964c325
...@@ -436,8 +436,8 @@ class LogitsProcessor(nn.Module): ...@@ -436,8 +436,8 @@ class LogitsProcessor(nn.Module):
if self.do_tensor_parallel_all_gather_dp_attn: if self.do_tensor_parallel_all_gather_dp_attn:
logits_metadata.compute_dp_attention_metadata(hidden_states) logits_metadata.compute_dp_attention_metadata(hidden_states)
hidden_states, local_hidden_states = ( hidden_states, local_hidden_states = (
logits_metadata.gathered_buffer, torch.empty_like(logits_metadata.gathered_buffer),
hidden_states.clone(), hidden_states,
) )
dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata) dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)
......
...@@ -1840,11 +1840,6 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -1840,11 +1840,6 @@ class DeepseekV2DecoderLayer(nn.Module):
hidden_states, residual, forward_batch hidden_states, residual, forward_batch
) )
if self.enable_dp_attention and self.speculative_algorithm.is_eagle():
# NOTE: this line resolves the degradation of MTP reception rate for non-zero DP ranks.
# See discussion here (https://github.com/sgl-project/sglang/pull/6081#discussion_r2147452251).
hidden_states = hidden_states.clone()
return hidden_states, residual return hidden_states, residual
def op_comm_prepare_attn( def op_comm_prepare_attn(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment