Unverified Commit a3d7f4b6 authored by Cheng Wan's avatar Cheng Wan Committed by GitHub
Browse files

fix communicator for non-dp lm head (#6662)

parent b18416fb
......@@ -50,10 +50,18 @@ class ScatterMode(Enum):
FULL = auto()
@staticmethod
def model_input_output():
"""The scatter mode for model forward pass input and output data"""
def model_input_mode():
"""The scatter mode for model input data"""
return ScatterMode.TP_ATTN_FULL
@staticmethod
def model_output_mode():
"""The scatter mode for model output data"""
if global_server_args_dict["enable_dp_lm_head"]:
return ScatterMode.TP_ATTN_FULL
else:
return ScatterMode.FULL
@dataclass
class _LayerModeComputationContext:
......@@ -95,7 +103,7 @@ class LayerScatterModes:
@classmethod
def _compute_layer_input_mode(cls, context: _LayerModeComputationContext):
if context.layer_id == 0:
return ScatterMode.model_input_output()
return ScatterMode.model_input_mode()
return cls._compute_layer_output_mode(context.previous_layer())
@classmethod
......@@ -126,7 +134,7 @@ class LayerScatterModes:
def _compute_layer_output_mode(cls, context: _LayerModeComputationContext):
mlp_mode = cls._compute_mlp_mode(context)
if context.layer_id == context.num_layers - 1:
return ScatterMode.model_input_output()
return ScatterMode.model_output_mode()
if mlp_mode == ScatterMode.SCATTERED:
return ScatterMode.SCATTERED
if mlp_mode == ScatterMode.FULL:
......
......@@ -451,7 +451,7 @@ class Qwen2MoeModel(nn.Module):
hidden_states, residual = model_forward_maybe_tbo(
layers=self.layers,
enable_tbo=True,
input_data_scatter_mode=ScatterMode.model_input_output(),
input_data_scatter_mode=ScatterMode.model_input_mode(),
positions=positions,
forward_batch=forward_batch,
hidden_states=hidden_states,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment