#include <ATen/record_function.h>
#include <torch/all.h>

#include <cstdlib>  // std::getenv
#include <string>   // std::stoi

#include "shm.h"

// Communication settings
static int world_rank = -1;
static int world_size = -1;

static bool is_initialized = false;

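// True when world size equals LOCAL_SIZE, i.e. every rank runs on this
// machine (determined in initialize()).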
static bool all_ranks_local_p = false;

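// Record the communicator geometry and, if all ranks are local, bring up the
// shared-memory transport via shm_initialize(). Repeat calls are no-ops.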
void initialize(int64_t size, int64_t rank) {
  if (is_initialized) {
    return;
  }

  // Check whether all ranks are on the same physical machine.
  // If so, we use an SHM-based low-latency allreduce.

  auto ls_string = std::getenv("LOCAL_SIZE");
  int ls = 0;
  if (ls_string != NULL) {
    ls = std::stoi(ls_string);
  }

  if (size >= 1 && size == ls) {
    all_ranks_local_p = true;
  }

  world_size = size;
  world_rank = rank;
  is_initialized = true;

  const char* addr_string = std::getenv("MASTER_ADDR");
  if (addr_string == NULL) {
    addr_string = "";
  }
  const char* port_string = std::getenv("MASTER_PORT");
  if (port_string == NULL) {
    port_string = "";
  }

  if (all_ranks_local_p) {
    shm_initialize(size, rank, addr_string, port_string);
  }
}
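
// Launch sketch (assumption: the job launcher exports LOCAL_SIZE alongside
// the usual MASTER_ADDR/MASTER_PORT rendezvous variables; torchrun itself
// sets LOCAL_WORLD_SIZE, so a wrapper would have to map it over):
//
//   LOCAL_SIZE=4 MASTER_ADDR=127.0.0.1 MASTER_PORT=29500 <launcher> app.py
//
// With size == LOCAL_SIZE == 4, all four ranks take the SHM path above.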

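// In-place allreduce over shared memory. `op` must be ReduceOp::SUM; `data`
// is replaced by the elementwise sum across all ranks.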
void shm_allreduce(torch::Tensor& data, int64_t op) {
  RECORD_FUNCTION("sgl-kernel::shm_allreduce", std::vector<c10::IValue>({data}));

  TORCH_CHECK(op == c10d::ReduceOp::SUM, "Only torch.distributed.ReduceOp.SUM is supported");

  auto numel = data.numel();
  int data_size = numel * data.element_size();
  all_reduce_outer_loop(data, numel, data_size);

}

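// Allgather over shared memory: concatenates each rank's `data` along `dim`
// (negative dims allowed), returning a tensor world_size times larger there.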
torch::Tensor shm_allgather(torch::Tensor& data, int64_t dim) {
  RECORD_FUNCTION("sgl-kernel::shm_allgather", std::vector<c10::IValue>({data}));

  auto numel = data.numel();
  int data_size = numel * data.element_size();
  if (dim < 0) {
    dim += data.dim();
  }
  std::vector<int64_t> result_shape = data.sizes().vec();
  result_shape[dim] *= world_size;
  torch::Tensor result_tensor = torch::empty(result_shape, data.options());
  return all_gather(result_tensor, data, dim, numel, data_size);
}
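
// Binding sketch (assumption: the actual op registration lives elsewhere in
// the build; the library name `sgl_kernel` and the schema strings below are
// illustrative, not the project's real ones):
//
//   TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
//     m.def("initialize(int size, int rank) -> ()", &initialize);
//     m.def("shm_allreduce(Tensor(a!) data, int op) -> ()", &shm_allreduce);
//     m.def("shm_allgather(Tensor data, int dim) -> Tensor", &shm_allgather);
//   }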