# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from dataclasses import dataclass
from enum import Enum, auto
from functools import partial
from typing import Dict, Optional

import torch

from sglang.srt.distributed import (
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_reduce,
)
from sglang.srt.layers.dp_attention import (
    attn_tp_all_gather_into_tensor,
    attn_tp_reduce_scatter_tensor,
    dp_gather_partial,
    dp_reduce_scatter_tensor,
    dp_scatter,
    get_attention_dp_size,
    get_attention_tp_rank,
    get_attention_tp_size,
    get_global_dp_buffer,
    get_local_dp_buffer,
    is_dp_attention_enabled,
)
from sglang.srt.layers.moe import (
    get_moe_a2a_backend,
    should_use_flashinfer_cutlass_moe_fp4_allgather,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import (
    get_bool_env_var,
    is_cuda,
    is_flashinfer_available,
    is_gfx95_supported,
    is_hip,
    is_sm90_supported,
    is_sm100_supported,
    prepare_weight_cache,
)

_is_flashinfer_available = is_flashinfer_available()
_is_sm90_supported = is_cuda() and is_sm90_supported()
_is_sm100_supported = is_cuda() and is_sm100_supported()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
_is_gfx95_supported = is_gfx95_supported()

if _use_aiter and _is_gfx95_supported:
    from sglang.srt.layers.quantization.rocm_mxfp4_utils import fused_rms_mxfp4_quant

FUSE_ALLREDUCE_MAX_BATCH_SIZE = 2048


class ScatterMode(Enum):
    """
    Suppose we have TP=4, DP=2, enable-dp-attention, and the system handles seq a,b,c,d
    Model input/output: [ab, ab, cd, cd] for four ranks respectively
    SCATTERED: [a, b, c, d]
    TP_ATTN_FULL: [ab, ab, cd, cd], i.e. all ranks inside a TP attn group have full data of the group
    FULL: [abcd, abcd, abcd, abcd]
    """

    SCATTERED = auto()
    TP_ATTN_FULL = auto()
    FULL = auto()

    @staticmethod
    def model_input_output():
        """The scatter mode for model forward pass input and output data"""
        return ScatterMode.TP_ATTN_FULL
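

# Illustrative sketch of the per-rank token layouts described in the ScatterMode
# docstring (assuming TP=4, DP=2 with dp-attention enabled and sequences a, b,
# c, d); explanatory only, not used at runtime:
#
#   mode            rank0    rank1    rank2    rank3
#   SCATTERED       [a]      [b]      [c]      [d]
#   TP_ATTN_FULL    [ab]     [ab]     [cd]     [cd]
#   FULL            [abcd]   [abcd]   [abcd]   [abcd]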


@dataclass
class _LayerModeComputationContext:
    num_layers: int
    layer_id: int
    is_layer_sparse: bool
    is_previous_layer_sparse: Optional[bool]

    def previous_layer(self):
        assert self.is_previous_layer_sparse is not None
        return _LayerModeComputationContext(
            layer_id=self.layer_id - 1,
            is_layer_sparse=self.is_previous_layer_sparse,
            is_previous_layer_sparse=None,
            num_layers=self.num_layers,
        )


@dataclass
class LayerScatterModes:
    layer_input_mode: ScatterMode
    attn_mode: ScatterMode
    # Can be further split into e.g. mlp_input_mode and mlp_output_mode if needed
    mlp_mode: ScatterMode
    middle_residual_mode: ScatterMode
    layer_output_mode: ScatterMode

    @classmethod
    def init_new(cls, **kwargs):
        context = _LayerModeComputationContext(**kwargs)
        return cls(
            layer_input_mode=cls._compute_layer_input_mode(context),
            attn_mode=ScatterMode.TP_ATTN_FULL,
            mlp_mode=cls._compute_mlp_mode(context),
            middle_residual_mode=cls._compute_middle_residual_mode(context),
            layer_output_mode=cls._compute_layer_output_mode(context),
        )

    @classmethod
    def _compute_layer_input_mode(cls, context: _LayerModeComputationContext):
        if context.layer_id == 0:
            return ScatterMode.model_input_output()
        return cls._compute_layer_output_mode(context.previous_layer())

    @classmethod
    def _compute_mlp_mode(cls, context: _LayerModeComputationContext):
        if context.is_layer_sparse:
            return (
                ScatterMode.SCATTERED
                if (
                    # Token dispatch/combine will be handled outside of LayerCommunicator for these modes.
                    not get_moe_a2a_backend().is_none()
                    or should_use_flashinfer_cutlass_moe_fp4_allgather()
                )
                else ScatterMode.FULL
            )
        else:
            return (
                ScatterMode.SCATTERED
                if enable_moe_dense_fully_dp()
                else ScatterMode.FULL
            )

    @classmethod
    def _compute_middle_residual_mode(cls, context: _LayerModeComputationContext):
        mlp_mode = cls._compute_mlp_mode(context)
        if mlp_mode == ScatterMode.SCATTERED:
            return ScatterMode.SCATTERED
        if mlp_mode == ScatterMode.FULL:
            return ScatterMode.TP_ATTN_FULL
        raise NotImplementedError

    @classmethod
    def _compute_layer_output_mode(cls, context: _LayerModeComputationContext):
        mlp_mode = cls._compute_mlp_mode(context)
        if context.layer_id == context.num_layers - 1:
            return ScatterMode.model_input_output()
        if mlp_mode == ScatterMode.SCATTERED:
            return ScatterMode.SCATTERED
        if mlp_mode == ScatterMode.FULL:
            return ScatterMode.TP_ATTN_FULL
        raise NotImplementedError
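

# A minimal usage sketch (hedged): decoder layers are expected to construct
# LayerScatterModes once per layer, roughly as below; `config`, `layer_id`, and
# the sparsity flags are placeholders for whatever the calling model provides.
#
#   scatter_modes = LayerScatterModes.init_new(
#       layer_id=layer_id,
#       num_layers=config.num_hidden_layers,
#       is_layer_sparse=is_layer_sparse,
#       is_previous_layer_sparse=is_previous_layer_sparse,
#   )
#
# For example, a sparse layer whose MoE uses an all-to-all backend gets
# mlp_mode == SCATTERED, hence middle_residual_mode == SCATTERED and (unless it
# is the last layer) layer_output_mode == SCATTERED as well.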


def enable_moe_dense_fully_dp():
    return global_server_args_dict["moe_dense_tp_size"] == 1


class LayerCommunicator:
    def __init__(
        self,
        layer_scatter_modes: LayerScatterModes,
        input_layernorm: torch.nn.Module,
        post_attention_layernorm: torch.nn.Module,
        # Reduce scatter requires skipping all-reduce in model code after MoE/MLP, so
        # only enable for models which have that implemented. Remove flag once done
        # for all models that use LayerCommunicator.
        allow_reduce_scatter: bool = False,
        is_last_layer: bool = False,
    ):
        self.layer_scatter_modes = layer_scatter_modes
        self.input_layernorm = input_layernorm
        self.post_attention_layernorm = post_attention_layernorm
        self.allow_reduce_scatter = allow_reduce_scatter
        self.is_last_layer = is_last_layer

        self._context = CommunicateContext.init_new()
        self._communicate_simple_fn = CommunicateSimpleFn.get_fn(
            input_mode=self.layer_scatter_modes.layer_input_mode,
            output_mode=self.layer_scatter_modes.attn_mode,
            context=self._context,
        )
        self._communicate_with_all_reduce_and_layer_norm_fn = (
            CommunicateWithAllReduceAndLayerNormFn.get_fn(
                hidden_states_input_mode=self.layer_scatter_modes.attn_mode,
                residual_input_mode=self.layer_scatter_modes.layer_input_mode,
                hidden_states_output_mode=self.layer_scatter_modes.mlp_mode,
                residual_output_mode=self.layer_scatter_modes.middle_residual_mode,
                context=self._context,
            )
        )
        self._communicate_summable_tensor_pair_fn = (
            CommunicateSummableTensorPairFn.get_fn(
                hidden_states_input_mode=self.layer_scatter_modes.mlp_mode,
                residual_input_mode=self.layer_scatter_modes.middle_residual_mode,
                output_mode=self.layer_scatter_modes.layer_output_mode,
                context=self._context,
            )
        )

    def prepare_attn(
        self,
        hidden_states: torch.Tensor,
        residual: torch.Tensor,
        forward_batch: ForwardBatch,
        quant_format: str = "",
    ):
        if hidden_states.shape[0] == 0:
            residual = hidden_states
        else:
            if (
                residual is not None
                and hasattr(hidden_states, "_sglang_needs_allreduce_fusion")
                and hidden_states._sglang_needs_allreduce_fusion
            ):
                hidden_states, residual = (
                    self.input_layernorm.forward_with_allreduce_fusion(
                        hidden_states, residual
                    )
                )
            else:
                if residual is None:
                    residual = hidden_states

                    if _use_aiter and _is_gfx95_supported and ("mxfp4" in quant_format):
                        hidden_states = fused_rms_mxfp4_quant(
                            hidden_states,
                            self.input_layernorm.weight,
                            self.input_layernorm.variance_epsilon,
                            None,
                            None,
                            None,
                            None,
                        )
                    else:
                        hidden_states = self.input_layernorm(hidden_states)
                else:
                    if _use_aiter and _is_gfx95_supported and ("mxfp4" in quant_format):
                        hidden_states, residual = fused_rms_mxfp4_quant(
                            hidden_states,
                            self.input_layernorm.weight,
                            self.input_layernorm.variance_epsilon,
                            None,
                            None,
                            None,
                            residual,
                        )
                    else:
                        hidden_states, residual = self.input_layernorm(
                            hidden_states, residual
                        )

        hidden_states = self._communicate_simple_fn(
            hidden_states=hidden_states,
            forward_batch=forward_batch,
            context=self._context,
        )

        return hidden_states, residual

    def prepare_mlp(
        self,
        hidden_states: torch.Tensor,
        residual: torch.Tensor,
        forward_batch: ForwardBatch,
        cache=None,
    ):
        if cache is not None:
            self._context.cache = cache

        return self._communicate_with_all_reduce_and_layer_norm_fn(
            hidden_states=hidden_states,
            residual=residual,
            forward_batch=forward_batch,
            layernorm=self.post_attention_layernorm,
            context=self._context,
        )

    def postprocess_layer(
        self,
        hidden_states: torch.Tensor,
        residual: torch.Tensor,
        forward_batch: ForwardBatch,
    ):
        return self._communicate_summable_tensor_pair_fn(
            hidden_states=hidden_states,
            residual=residual,
            forward_batch=forward_batch,
            context=self._context,
            allow_reduce_scatter=self.allow_reduce_scatter,
        )

    def should_use_reduce_scatter(self, forward_batch: ForwardBatch):
        return (
            self.allow_reduce_scatter
            and self._communicate_summable_tensor_pair_fn
            is CommunicateSummableTensorPairFn._scatter_hidden_states
            and forward_batch.dp_padding_mode.is_max_len()
        )

    def should_fuse_mlp_allreduce_with_next_layer(
        self, forward_batch: ForwardBatch
    ) -> bool:
        speculative_algo = global_server_args_dict.get("speculative_algorithm", None)
        if (
            is_dp_attention_enabled()
            and speculative_algo is not None
            and speculative_algo.is_eagle()
        ):
            return False

        batch_size = (
            forward_batch.input_ids.shape[0]
            if hasattr(forward_batch, "input_ids")
            else 0
        )
        if batch_size > FUSE_ALLREDUCE_MAX_BATCH_SIZE:
            return False

        static_conditions_met = (
            (not self.is_last_layer)
            and (self._context.tp_size > 1)
            and global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False)
            and _is_flashinfer_available
        )

        if not static_conditions_met:
            return False

        return (
            batch_size > 0
            and batch_size <= FUSE_ALLREDUCE_MAX_BATCH_SIZE
            and (not self.is_last_layer)
        )
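

# A minimal, hedged sketch of how a decoder layer is expected to drive
# LayerCommunicator; `self.self_attn` and `self.mlp` are placeholders for the
# calling model's own modules:
#
#   def forward(self, positions, hidden_states, forward_batch, residual):
#       hidden_states, residual = self.layer_communicator.prepare_attn(
#           hidden_states, residual, forward_batch
#       )
#       hidden_states = self.self_attn(positions, hidden_states, forward_batch)
#       hidden_states, residual = self.layer_communicator.prepare_mlp(
#           hidden_states, residual, forward_batch
#       )
#       hidden_states = self.mlp(hidden_states)
#       hidden_states, residual = self.layer_communicator.postprocess_layer(
#           hidden_states, residual, forward_batch
#       )
#       return hidden_states, residual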


@dataclass
class CommunicateContext:
    process_group_sizes: Dict[ScatterMode, int]
    attn_tp_rank: int
    attn_tp_size: int
    attn_dp_size: int
    tp_size: int
    cache = None

    def is_same_group_size(self, a: ScatterMode, b: ScatterMode):
        return self.process_group_sizes[a] == self.process_group_sizes[b]

    @classmethod
    def init_new(cls):
        attn_tp_rank = get_attention_tp_rank()
        attn_tp_size = get_attention_tp_size()
        attn_dp_size = get_attention_dp_size()
        tp_size = get_tensor_model_parallel_world_size()
        process_group_sizes = {
            ScatterMode.SCATTERED: 1,
            ScatterMode.TP_ATTN_FULL: attn_tp_size,
            # TODO: support --moe-dense-tp-size > 1
            ScatterMode.FULL: tp_size,
        }
        return cls(
            process_group_sizes=process_group_sizes,
            attn_tp_rank=attn_tp_rank,
            attn_tp_size=attn_tp_size,
            attn_dp_size=attn_dp_size,
            tp_size=tp_size,
        )
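
    # Illustrative values (hedged): assuming tp_size=4 and attn_tp_size=2,
    # process_group_sizes is
    #   {ScatterMode.SCATTERED: 1, ScatterMode.TP_ATTN_FULL: 2, ScatterMode.FULL: 4},
    # so is_same_group_size(ScatterMode.SCATTERED, ScatterMode.TP_ATTN_FULL) is False.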


class CommunicateSimpleFn:
    @staticmethod
    def get_fn(
        input_mode: ScatterMode,
        output_mode: ScatterMode,
        context: CommunicateContext,
    ):
        if context.is_same_group_size(input_mode, output_mode):
            return CommunicateSimpleFn._trivial

        if (input_mode == ScatterMode.SCATTERED) and (
            output_mode == ScatterMode.TP_ATTN_FULL
        ):
            return CommunicateSimpleFn._scattered_to_tp_attn_full

        raise NotImplementedError(f"{input_mode=} {output_mode=}")

    @staticmethod
    def _trivial(
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        context: CommunicateContext,
    ) -> torch.Tensor:
        return hidden_states

    @staticmethod
    def _scattered_to_tp_attn_full(
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        context: CommunicateContext,
    ) -> torch.Tensor:
        hidden_states, local_hidden_states = (
            get_local_dp_buffer(),
            hidden_states,
        )
        attn_tp_all_gather_into_tensor(
            hidden_states,
            local_hidden_states,
        )
        return hidden_states


class CommunicateWithAllReduceAndLayerNormFn:
    """Besides communication, needs to
    1. All reduce in tp_attn_group on hidden_states
    2. Apply layer norm
    """

    @staticmethod
    def get_fn(
        hidden_states_input_mode: ScatterMode,
        residual_input_mode: ScatterMode,
        hidden_states_output_mode: ScatterMode,
        residual_output_mode: ScatterMode,
        context: CommunicateContext,
    ):

        if (
            context.is_same_group_size(
                hidden_states_input_mode, hidden_states_output_mode
            )
            and context.is_same_group_size(residual_input_mode, residual_output_mode)
            and context.attn_tp_size == 1
        ):
            return CommunicateWithAllReduceAndLayerNormFn._simple

        if (
            (hidden_states_input_mode == ScatterMode.TP_ATTN_FULL)
            and (
                residual_input_mode in [ScatterMode.SCATTERED, ScatterMode.TP_ATTN_FULL]
            )
            and (hidden_states_output_mode == ScatterMode.FULL)
            and (residual_output_mode == ScatterMode.TP_ATTN_FULL)
        ):
            return partial(
                CommunicateWithAllReduceAndLayerNormFn._gather_hidden_states_and_residual,
                residual_input_mode=residual_input_mode,
            )

        if (
            (hidden_states_input_mode == ScatterMode.TP_ATTN_FULL)
            and (
                residual_input_mode in [ScatterMode.SCATTERED, ScatterMode.TP_ATTN_FULL]
            )
            and (hidden_states_output_mode == ScatterMode.SCATTERED)
            and (residual_output_mode == ScatterMode.SCATTERED)
        ):
            return partial(
                CommunicateWithAllReduceAndLayerNormFn._scatter_hidden_states_and_residual,
                residual_input_mode=residual_input_mode,
            )

        raise NotImplementedError(
            f"{hidden_states_input_mode=} {residual_input_mode=} {hidden_states_output_mode=} {residual_output_mode=}"
        )

    @staticmethod
    def _simple(
        hidden_states: torch.Tensor,
        residual: torch.Tensor,
        forward_batch: ForwardBatch,
        layernorm: torch.nn.Module,
        context: CommunicateContext,
    ):
        # TODO move these `if shape != 0` into LayerNorm itself
        if hidden_states.shape[0] != 0:
            hidden_states, residual = layernorm(hidden_states, residual)
        return hidden_states, residual

    @staticmethod
    def _gather_hidden_states_and_residual(
        hidden_states: torch.Tensor,
        residual: torch.Tensor,
        forward_batch: ForwardBatch,
        layernorm: torch.nn.Module,
        context: CommunicateContext,
        *,
        residual_input_mode,
    ):
        if residual_input_mode == ScatterMode.SCATTERED and context.attn_tp_size > 1:
            residual, local_residual = (
                get_local_dp_buffer(),
                residual,
            )
            attn_tp_all_gather_into_tensor(residual, local_residual)
        if context.attn_dp_size != 1:
            if context.attn_tp_rank == 0:
                hidden_states += residual

            # Perform layernorm on smaller data before comm. Only valid when attn_tp_size is 1 (tp_size == dp_size)
            use_layer_norm_before_gather = context.attn_tp_size == 1
            if use_layer_norm_before_gather and hidden_states.shape[0] != 0:
                residual = hidden_states
                hidden_states = layernorm(hidden_states)
            hidden_states, local_hidden_states = (
                get_global_dp_buffer(),
                hidden_states,
            )
            dp_gather_partial(hidden_states, local_hidden_states, forward_batch)

            if not use_layer_norm_before_gather:
                dp_scatter(residual, hidden_states, forward_batch)
                if hidden_states.shape[0] != 0:
                    hidden_states = layernorm(hidden_states)
        else:
            # According to the discussion in https://github.com/flashinfer-ai/flashinfer/issues/1223#issuecomment-3047256465
            # We set the max token num to 128 for allreduce fusion with the
            # min-latency case (use_oneshot=True).
            if (
                (_is_sm100_supported or _is_sm90_supported)
                and _is_flashinfer_available
                and hasattr(layernorm, "forward_with_allreduce_fusion")
                and global_server_args_dict["enable_flashinfer_allreduce_fusion"]
                and hidden_states.shape[0] <= 4096
            ):
                hidden_states, residual = layernorm.forward_with_allreduce_fusion(
                    hidden_states, residual
                )
            else:
                hidden_states = tensor_model_parallel_all_reduce(hidden_states)
                if context.cache is not None:
                    _ = prepare_weight_cache(hidden_states, context.cache)
                hidden_states, residual = layernorm(hidden_states, residual)
        return hidden_states, residual

    @staticmethod
    def _scatter_hidden_states_and_residual(
        hidden_states: torch.Tensor,
        residual: torch.Tensor,
        forward_batch: ForwardBatch,
        layernorm: torch.nn.Module,
        context: CommunicateContext,
        *,
        residual_input_mode,
    ):
        input_hidden_states = hidden_states
        hidden_states = hidden_states.tensor_split(context.attn_tp_size)[
            context.attn_tp_rank
        ]
        attn_tp_reduce_scatter_tensor(hidden_states, input_hidden_states)
        if residual_input_mode == ScatterMode.TP_ATTN_FULL:
            residual = residual.tensor_split(context.attn_tp_size)[context.attn_tp_rank]
        if hidden_states.shape[0] != 0:
            hidden_states, residual = layernorm(hidden_states, residual)
        return hidden_states, residual


class CommunicateSummableTensorPairFn:
    """It is allowed to make (hidden_states, residual) := (hidden_states + residual, None) if needed."""

    @classmethod
    def execute(
        cls,
        hidden_states_input_mode,
        residual_input_mode,
        output_mode,
        context,
        **kwargs,
    ):
        return cls.get_fn(
            hidden_states_input_mode=hidden_states_input_mode,
            residual_input_mode=residual_input_mode,
            output_mode=output_mode,
            context=context,
        )(context=context, **kwargs)

    @staticmethod
    def get_fn(
        hidden_states_input_mode: ScatterMode,
        residual_input_mode: ScatterMode,
        output_mode: ScatterMode,
        context: CommunicateContext,
    ):
        if context.is_same_group_size(
            hidden_states_input_mode, output_mode
        ) and context.is_same_group_size(residual_input_mode, output_mode):
            return CommunicateSummableTensorPairFn._trivial

        if (
            (hidden_states_input_mode == ScatterMode.FULL)
            and (residual_input_mode == ScatterMode.TP_ATTN_FULL)
            and (output_mode == ScatterMode.TP_ATTN_FULL)
        ):
            return CommunicateSummableTensorPairFn._scatter_hidden_states

        if (
            (hidden_states_input_mode == ScatterMode.SCATTERED)
            and (residual_input_mode == ScatterMode.SCATTERED)
            and (output_mode == ScatterMode.TP_ATTN_FULL)
        ):
            return CommunicateSummableTensorPairFn._gather

        if (
            (hidden_states_input_mode == ScatterMode.TP_ATTN_FULL)
            and (residual_input_mode == ScatterMode.TP_ATTN_FULL)
            and (output_mode == ScatterMode.SCATTERED)
        ):
            return CommunicateSummableTensorPairFn._scatter

        raise NotImplementedError(
            f"{hidden_states_input_mode=} {residual_input_mode=} {output_mode=}"
        )

    @staticmethod
    def _trivial(
        hidden_states: torch.Tensor,
        residual: torch.Tensor,
        forward_batch: ForwardBatch,
        context: CommunicateContext,
        **kwargs,
    ):
        return hidden_states, residual

    @staticmethod
    def _scatter_hidden_states(
        hidden_states: torch.Tensor,
        residual: torch.Tensor,
        forward_batch: ForwardBatch,
        context: CommunicateContext,
        allow_reduce_scatter: bool = False,
    ):
        hidden_states, global_hidden_states = (
            get_local_dp_buffer(),
            hidden_states,
        )
        if allow_reduce_scatter and forward_batch.dp_padding_mode.is_max_len():
            # When using padding, all_reduce is skipped after MLP and MOE and reduce scatter is used here instead.
            dp_reduce_scatter_tensor(hidden_states, global_hidden_states)
        else:
            dp_scatter(hidden_states, global_hidden_states, forward_batch)
        return hidden_states, residual

    @staticmethod
    def _gather(
        hidden_states: torch.Tensor,
        residual: torch.Tensor,
        forward_batch: ForwardBatch,
        context: CommunicateContext,
        **kwargs,
    ):
        hidden_states += residual
        residual = None
        hidden_states, local_hidden_states = (
            get_local_dp_buffer(),
            hidden_states,
        )
        attn_tp_all_gather_into_tensor(
            hidden_states,
            local_hidden_states,
        )
        return hidden_states, residual

    @staticmethod
    def _scatter(
        hidden_states: torch.Tensor,
        residual: torch.Tensor,
        forward_batch: ForwardBatch,
        context: CommunicateContext,
    ):
        assert residual is None, "not yet handled residual!=None"
        tensor_list = list(hidden_states.tensor_split(context.attn_tp_size))
        hidden_states = tensor_list[context.attn_tp_rank]
        return hidden_states, residual