tuner.cc 4.33 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
/*************************************************************************
 * Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved.
 *
 * Runtime tuner for RCCL - Simple brute force search for algorithm and protocol
 ************************************************************************/

#include "comm.h"
#include "core.h"
#include <cfloat>

// Build the cache key that identifies a workload: (collective, element count,
// datatype).
//
// Key layout (bit 0 = least significant):
//   bits  0..7   datatype (low 8 bits of the enum value)
//   bits  8..15  collective (low 8 bits of the enum value)
//   bits 16..63  element count, rotated left by 16
//
// Each field lives in its own bit range, so two workloads with count < 2^48
// and enum values below 256 always get different keys. This matters because
// the key alone decides which cached configuration is reused; folding the
// count over the collective's bits would let, for example,
// (coll=0, count=1028) and (coll=4, count=1024) share one cache entry.
static uint64_t hashWorkload(ncclFunc_t coll, size_t count, ncclDataType_t datatype) {
  const uint64_t n = (uint64_t)count;
  // A rotation discards no count bits: bits 48..63 wrap into the low 16 bits
  // and are only non-zero for counts of 2^48 elements or more.
  const uint64_t rotatedCount = (n << 16) | (n >> 48);
  const uint64_t collBits = ((uint64_t)coll & 0xffu) << 8;
  const uint64_t typeBits = (uint64_t)datatype & 0xffu;
  return rotatedCount ^ collBits ^ typeBits;
}

// Translate a search step into the (algorithm, protocol) pair to try.
//
// The protocol varies fastest and the algorithm changes every two steps:
//   step 0 -> Ring + Simple     step 1 -> Ring + LL
//   step 2 -> Tree + Simple     step 3 -> Tree + LL
// No range check is done here; `step` must be in [0, 3]. The chosen values
// are written through `algo` and `proto`.
static void getTestConfig(int step, int* algo, int* proto) {
  const int numProtos = 2;
  const int candidateProtos[] = {NCCL_PROTO_SIMPLE, NCCL_PROTO_LL};
  const int candidateAlgos[] = {NCCL_ALGO_RING, NCCL_ALGO_TREE};

  *algo = candidateAlgos[step / numProtos];
  *proto = candidateProtos[step % numProtos];
}

/*
 * Pick the algorithm/protocol to use for the collective described by `info`.
 *
 * Possible outcomes (the function always returns ncclSuccess):
 *  - Tuner disabled: *needsTuning = false; *algo and *proto are not written.
 *  - Workload found in the cache: the cached pair is decoded into *algo and
 *    *proto (stored as algo * 10 + proto) and *needsTuning = false.
 *  - Search in progress, or started by this call: the candidate for the
 *    current search step is written to *algo / *proto, remembered in
 *    tuner.currentAlgo / tuner.currentProto, and *needsTuning = true.
 *  - Search step counter already past the last candidate: the best pair
 *    recorded so far is returned and *needsTuning = false.
 *
 * A search is (re)started whenever no search is active or the active search
 * belongs to a different workload key.
 */
ncclResult_t ncclTunerGetConfig(struct ncclComm* comm, struct ncclInfo* info,
                                 int* algo, int* proto, bool* needsTuning) {
  auto& tuner = comm->tuner;

  if (!tuner.enabled) {
    *needsTuning = false;
    return ncclSuccess;
  }

  // Key identifying this (collective, count, datatype) combination.
  const uint64_t key = hashWorkload(info->coll, info->count, info->datatype);

  // Already tuned: decode and return the stored configuration.
  const auto cached = tuner.workloadCache->find(key);
  if (cached != tuner.workloadCache->end()) {
    const int packed = cached->second;
    *algo = packed / 10;
    *proto = packed % 10;
    *needsTuning = false;

    INFO(NCCL_TUNER, "Rank %d: Using cached config for workload %llx: algo=%d proto=%d",
         comm->rank, (unsigned long long)key, *algo, *proto);
    return ncclSuccess;
  }

  // Not cached: continue the running search only if it is for this workload,
  // otherwise reset the search state and start over.
  const bool resuming = tuner.isSearching && tuner.currentWorkloadHash == key;
  if (!resuming) {
    tuner.isSearching = true;
    tuner.currentWorkloadHash = key;
    tuner.searchStep = 0;
    tuner.bestTime = FLT_MAX;

    INFO(NCCL_INIT, "Rank %d: New workload %llx detected, starting tuning (coll=%d count=%zu dtype=%d)",
         comm->rank, (unsigned long long)key, info->coll, info->count, info->datatype);
  }

  const int totalConfigs = 4;  // 2 algos x 2 protos

  if (tuner.searchStep >= totalConfigs) {
    // Every candidate has been handed out: report the best one seen.
    *algo = tuner.bestAlgo;
    *proto = tuner.bestProto;
    *needsTuning = false;
    return ncclSuccess;
  }

  // Hand out the next candidate and remember it so the measurement reported
  // later can be attributed to it.
  getTestConfig(tuner.searchStep, algo, proto);
  tuner.currentAlgo = *algo;
  tuner.currentProto = *proto;
  *needsTuning = true;

  INFO(NCCL_TUNER, "Rank %d: Testing config %d/%d: algo=%d proto=%d",
       comm->rank, tuner.searchStep + 1, totalConfigs, *algo, *proto);

  return ncclSuccess;
}

/*
 * Record the measured time of the candidate stored in tuner.currentAlgo /
 * tuner.currentProto and advance the search by one step.
 *
 * Does nothing when the tuner is disabled or no search is active. Otherwise:
 *  - if `elapsedMs` is lower than the best time so far, the current candidate
 *    becomes the best one;
 *  - the search step is incremented;
 *  - once all 4 candidates have been measured, the best pair is stored in the
 *    workload cache as algo * 10 + proto under the current workload key and
 *    the search is marked finished.
 *
 * Always returns ncclSuccess.
 */
ncclResult_t ncclTunerRecordPerformance(struct ncclComm* comm, float elapsedMs) {
  auto& tuner = comm->tuner;

  const bool searchActive = tuner.enabled && tuner.isSearching;
  if (!searchActive) {
    return ncclSuccess;
  }

  INFO(NCCL_TUNER, "Rank %d: Config %d (algo=%d proto=%d) time: %.3f ms",
       comm->rank, tuner.searchStep,
       tuner.currentAlgo, tuner.currentProto, elapsedMs);

  // Keep the fastest candidate seen so far.
  const bool isNewBest = elapsedMs < tuner.bestTime;
  if (isNewBest) {
    tuner.bestTime = elapsedMs;
    tuner.bestAlgo = tuner.currentAlgo;
    tuner.bestProto = tuner.currentProto;
  }

  tuner.searchStep += 1;

  const int totalConfigs = 4;
  if (tuner.searchStep < totalConfigs) {
    // More candidates left to measure.
    return ncclSuccess;
  }

  // All candidates measured: publish the winner and close the search.
  INFO(NCCL_INIT, "Rank %d: Tuning complete for workload %llx! Best: algo=%d proto=%d time=%.3f ms",
       comm->rank, (unsigned long long)tuner.currentWorkloadHash,
       tuner.bestAlgo, tuner.bestProto, tuner.bestTime);

  const int packed = tuner.bestAlgo * 10 + tuner.bestProto;
  (*tuner.workloadCache)[tuner.currentWorkloadHash] = packed;

  tuner.isSearching = false;

  return ncclSuccess;
}