@@ -27,6 +27,7 @@ from sglang.srt.layers.dp_attention import (
    attn_tp_all_gather_into_tensor,
    attn_tp_reduce_scatter_tensor,
    dp_gather_partial,
    dp_reduce_scatter_tensor,
    dp_scatter,
    get_attention_dp_size,
    get_attention_tp_rank,
...
...
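The new import, `dp_reduce_scatter_tensor`, is the DP-group counterpart of the existing gather/scatter helpers. As a rough illustration of the communication pattern it names (not sglang's actual wrapper or its signature), a reduce-scatter sums a tensor across ranks and leaves each rank holding only its shard of the result. A minimal sketch with plain `torch.distributed`, assuming an initialized process group:

```python
# Sketch only: illustrates the reduce-scatter pattern with plain
# torch.distributed. sglang's dp_reduce_scatter_tensor operates on its
# DP process group and has its own signature; names here are illustrative.
import torch
import torch.distributed as dist

def reduce_scatter_sketch(partial: torch.Tensor) -> torch.Tensor:
    """Sum `partial` across ranks and return only this rank's shard.

    `partial` is each rank's full-size partial result; dim 0 must divide
    evenly by the world size.
    """
    world_size = dist.get_world_size()
    assert partial.shape[0] % world_size == 0
    shard = torch.empty(
        (partial.shape[0] // world_size, *partial.shape[1:]),
        dtype=partial.dtype,
        device=partial.device,
    )
    # Element-wise sum across ranks; each rank keeps one contiguous chunk
    # of dim 0. A ring all-reduce is reduce-scatter + all-gather, so
    # stopping after the reduce-scatter roughly halves the traffic when
    # each rank only needs its own rows.
    dist.reduce_scatter_tensor(shard, partial, op=dist.ReduceOp.SUM)
    return shard
```

Skipping the all-gather half is the whole point, and it is also why the model code must not all-reduce the same tensor afterwards, which is what the flag introduced in the next hunk guards.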
@@ -149,10 +150,13 @@ class LayerCommunicator:
        layer_scatter_modes: LayerScatterModes,
        input_layernorm: torch.nn.Module,
        post_attention_layernorm: torch.nn.Module,
        # Reduce scatter requires the model code to skip its all-reduce after
        # MoE/MLP, so only enable this for models that implement that skip.
        # Remove this flag once all models that use LayerCommunicator do.
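To make the contract in that comment concrete, here is a hypothetical layer showing why the flag must stay off until a model's own all-reduce is removed: if the communicator reduce-scatters the MLP output while the layer also all-reduces it, the cross-rank sum is applied twice. The class and flag names below are assumptions for illustration, not sglang's actual interfaces:

```python
# Hypothetical illustration of the flag's contract; SimpleMLPLayer and
# allow_reduce_scatter are illustrative names, not sglang interfaces.
import torch
import torch.distributed as dist

class SimpleMLPLayer(torch.nn.Module):
    def __init__(self, hidden_size: int, allow_reduce_scatter: bool = False):
        super().__init__()
        self.mlp = torch.nn.Linear(hidden_size, hidden_size)
        # Only set True once the surrounding communicator is known to
        # reduce-scatter this layer's output.
        self.allow_reduce_scatter = allow_reduce_scatter

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        out = self.mlp(hidden_states)  # partial result on each rank
        if not self.allow_reduce_scatter:
            # Legacy path: the layer itself sums partials across ranks.
            dist.all_reduce(out)
        # Reduce-scatter path: return the partial sum untouched; the
        # communicator reduce-scatters it, so reducing here as well
        # would double-count the cross-rank sum.
        return out
```

Gating per model keeps the old all-reduce path correct everywhere while the skip is rolled out model by model.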