/* Copyright 2025 SGLang Team. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// reference:
// https://github.com/NVIDIA/TensorRT-LLM/blob/release/0.14/cpp/tensorrt_llm/plugins/ncclPlugin/allreducePlugin.cpp
/*
 * Copyright (c) 2022-2024, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <cassert>
#include <cuda_fp16.h>
#include <stdint.h>
#include <torch/all.h>

namespace trt_llm {
constexpr size_t WARP_SIZE = 32;
constexpr size_t MAX_ALL_REDUCE_BLOCKS = 36;
constexpr size_t MAX_RANKS_PER_NODE = 8;
constexpr size_t DEFAULT_BLOCK_SIZE = 512;

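// Strategy used by the custom all-reduce kernels:
//  - ONESHOT: each rank reads every peer's input and reduces the whole message locally;
//    lowest latency, best for small messages.
//  - TWOSHOT: reduce-scatter followed by all-gather, so each rank only reduces its own
//    shard; better bandwidth for larger messages.
//  - RING: placeholder for a ring-based fallback, not implemented by these kernels.
//  - AUTO: let SelectImplementation() pick based on message size and world size.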
enum class AllReduceStrategyType : int8_t {
  RING = 0,
  ONESHOT = 1,
  TWOSHOT = 2,
  AUTO = 3,
};

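// Table of every rank's communication buffer on the node, indexed by rank.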
struct RankData {
  void* ptrs[MAX_RANKS_PER_NODE];
};

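// Launch arguments shared by the one-shot and two-shot kernels. elts_* counts are in
// elements, with elts_size giving the element size in bytes; peer_barrier_ptrs_in/out
// hold the per-rank flags used to synchronize ranks before and after the reduction,
// and is_capturing marks whether a CUDA graph capture is in progress.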
struct AllReduceParams {
  size_t elts_size;
  size_t elts_total;
  size_t elts_per_rank;
  size_t elts_per_block;
  size_t rank_offset;
  size_t ranks_per_node, rank, local_rank;
  uint32_t barrier_flag;
  uint32_t* peer_barrier_ptrs_in[MAX_RANKS_PER_NODE];
  uint32_t* peer_barrier_ptrs_out[MAX_RANKS_PER_NODE];
  uint32_t* tmp_result_buffers[MAX_RANKS_PER_NODE];
  RankData* peer_comm_buffer_ptrs;
  void* local_input_buffer_ptr;
  void* local_output_buffer_ptr;
  bool is_capturing;
};

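// Largest message (in bytes) the custom all-reduce workspace can hold: 16 MiB for one
// or two ranks, 8 MiB otherwise.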
inline size_t GetMaxRequiredWorkspaceSize(int world_size) {
  if (world_size <= 2) {
    return 16 * 1024 * 1024;
  }
  return 8 * 1024 * 1024;
}

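// Heuristic strategy selection. Messages larger than the workspace would need the
// (unimplemented) ring fallback. Two ranks or fewer always use one-shot; otherwise the
// one-shot/two-shot cutover is 1 MiB for up to four ranks and 512 KiB beyond that.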
inline AllReduceStrategyType SelectImplementation(size_t message_size, int world_size) {
  const size_t maxWorkspaceSize = GetMaxRequiredWorkspaceSize(world_size);

  if (message_size > maxWorkspaceSize) {
    assert(false && "Custom allreduce does not support the ring strategy yet");
    return AllReduceStrategyType::RING;
  }

  if (world_size <= 2) {
    return AllReduceStrategyType::ONESHOT;
  }

  if (world_size <= 4) {
    if (message_size < 1 * 1024 * 1024) {
      return AllReduceStrategyType::ONESHOT;
    }
    return AllReduceStrategyType::TWOSHOT;
  }

  if (message_size < 512 * 1024) {
    return AllReduceStrategyType::ONESHOT;
  }
  return AllReduceStrategyType::TWOSHOT;
}

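// Runs the custom all-reduce described by `params` on `stream` for tensors of
// `data_type`, using the requested strategy.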
void trtCustomAllReduce(
    AllReduceParams& params, at::ScalarType data_type, AllReduceStrategyType strat, cudaStream_t stream);
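
// Illustrative usage sketch (assumes the caller has already exchanged peer pointers and
// filled `params`; `stream` and the element type below are placeholders, not part of
// this header):
//
//   size_t message_bytes = params.elts_total * params.elts_size;
//   AllReduceStrategyType strategy =
//       SelectImplementation(message_bytes, static_cast<int>(params.ranks_per_node));
//   trtCustomAllReduce(params, at::ScalarType::Half, strategy, stream);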

}  // namespace trt_llm