Unverified commit fa6723f0, authored by Yineng Zhang and committed by GitHub
Browse files

Revert "fix communicator for non-dp lm head (#6662)" (#6677)

parent 673ff668
......@@ -50,18 +50,10 @@ class ScatterMode(Enum):
FULL = auto()
@staticmethod
def model_input_mode():
"""The scatter mode for model input data"""
def model_input_output():
"""The scatter mode for model forward pass input and output data"""
return ScatterMode.TP_ATTN_FULL
@staticmethod
def model_output_mode():
"""The scatter mode for model output data"""
if global_server_args_dict["enable_dp_lm_head"]:
return ScatterMode.TP_ATTN_FULL
else:
return ScatterMode.FULL
@dataclass
class _LayerModeComputationContext:
......@@ -103,7 +95,7 @@ class LayerScatterModes:
@classmethod
def _compute_layer_input_mode(cls, context: _LayerModeComputationContext):
if context.layer_id == 0:
return ScatterMode.model_input_mode()
return ScatterMode.model_input_output()
return cls._compute_layer_output_mode(context.previous_layer())
@classmethod
......@@ -134,7 +126,7 @@ class LayerScatterModes:
def _compute_layer_output_mode(cls, context: _LayerModeComputationContext):
mlp_mode = cls._compute_mlp_mode(context)
if context.layer_id == context.num_layers - 1:
return ScatterMode.model_output_mode()
return ScatterMode.model_input_output()
if mlp_mode == ScatterMode.SCATTERED:
return ScatterMode.SCATTERED
if mlp_mode == ScatterMode.FULL:
......
......@@ -451,7 +451,7 @@ class Qwen2MoeModel(nn.Module):
hidden_states, residual = model_forward_maybe_tbo(
layers=self.layers,
enable_tbo=True,
input_data_scatter_mode=ScatterMode.model_input_mode(),
input_data_scatter_mode=ScatterMode.model_input_output(),
positions=positions,
forward_batch=forward_batch,
hidden_states=hidden_states,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment