@@ -2303,6 +2303,51 @@ class Withable(Generic[T]):
self._value=None
defrequire_mlp_tp_gather(server_args):
"""
Check if the input of MLP is obtained by all-gather rather than all-reduce. This only happens when each MLP TP group contains multiple attention DP groups.
"""
ifserver_args.enable_dp_attention:
assertserver_args.dp_size>1,"dp_size must be greater than 1"
if(
server_args.moe_dense_tp_sizeisNone
):# TODO(ch-wan): some MoE models do not have dense layers