"sgl-kernel/vscode:/vscode.git/clone" did not exist on "e0b2d3eebebd3d4efc7e323ad2dee605b607f394"
mpiUtils.cpp 17.5 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
/*
 * Copyright (c) 2022-2024, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <numeric>
#include <unordered_set>

#include "tensorrt_llm/common/mpiUtils.h"

#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/iBuffer.h"

#include <atomic>
#include <csignal>
#include <cstdlib>
#include <mutex>
#include <thread>
#include <type_traits>
#include <utility>
#ifndef _WIN32
#include <unistd.h>
#endif

// We rely on SizeType32 being int32_t in some places with weak type checking,
// i.e. we're passing void ptr to some function. To prevent mysterious errors
// in the future, we trigger a compilation error here if SizeType32 isn't int32_t.
static_assert(std::is_same<tensorrt_llm::runtime::SizeType32, std::int32_t>::value);

namespace tensorrt_llm::mpi
{

// Translates a MpiType tag into the MPI_Datatype handle passed to the MPI library.
// kHALF and kBF16 are both described to MPI as MPI_UINT16_T, and kFP8 as MPI_UINT8_T.
// The lookup uses unordered_map::at, so a tag without a table entry raises std::out_of_range.
// When ENABLE_MULTI_DEVICE is off the function only raises through TLLM_THROW.
MPI_Datatype getMpiDtype(MpiType dtype)
{
#if ENABLE_MULTI_DEVICE
    using DatatypeTable = std::unordered_map<MpiType, MPI_Datatype>;
    // Built once on first use and shared by all later calls.
    static DatatypeTable const kDatatypeByTag{
        {MpiType::kBYTE, MPI_BYTE},
        {MpiType::kHALF, MPI_UINT16_T},
        {MpiType::kFLOAT, MPI_FLOAT},
        {MpiType::kDOUBLE, MPI_DOUBLE},
        {MpiType::kBOOL, MPI_C_BOOL},
        {MpiType::kINT8, MPI_INT8_T},
        {MpiType::kUINT8, MPI_UINT8_T},
        {MpiType::kINT32, MPI_INT32_T},
        {MpiType::kUINT32, MPI_UINT32_T},
        {MpiType::kINT64, MPI_INT64_T},
        {MpiType::kUINT64, MPI_UINT64_T},
        {MpiType::kFP8, MPI_UINT8_T},
        {MpiType::kBF16, MPI_UINT16_T},
        {MpiType::kCHAR, MPI_CHAR},
    };
    return kDatatypeByTag.at(dtype);
#else
    TLLM_THROW("Multi device support is disabled.");
#endif
}

// Translates a MpiOp tag into the MPI_Op handle passed to MPI reduction calls.
// The lookup uses unordered_map::at, so a tag without a table entry raises std::out_of_range.
// When ENABLE_MULTI_DEVICE is off the function only raises through TLLM_THROW.
MPI_Op getMpiOp(MpiOp op)
{
#if ENABLE_MULTI_DEVICE
    using OpTable = std::unordered_map<MpiOp, MPI_Op>;
    // Built once on first use and shared by all later calls.
    static OpTable const kOpByTag{
        {MpiOp::NULLOP, MPI_OP_NULL},
        {MpiOp::MAX, MPI_MAX},
        {MpiOp::MIN, MPI_MIN},
        {MpiOp::SUM, MPI_SUM},
        {MpiOp::PROD, MPI_PROD},
        {MpiOp::LAND, MPI_LAND},
        {MpiOp::BAND, MPI_BAND},
        {MpiOp::LOR, MPI_LOR},
        {MpiOp::BOR, MPI_BOR},
        {MpiOp::LXOR, MPI_LXOR},
        {MpiOp::BXOR, MPI_BXOR},
        {MpiOp::MINLOC, MPI_MINLOC},
        {MpiOp::MAXLOC, MPI_MAXLOC},
        {MpiOp::REPLACE, MPI_REPLACE},
    };
    return kOpByTag.at(op);
#else
    TLLM_THROW("Multi device support is disabled.");
#endif // ENABLE_MULTI_DEVICE
}

namespace
{

// Set to true at the end of initialize(). It is atomic because initialize() reads it once
// before acquiring mpiMutex, so that unlocked read must not race with the final store.
std::atomic<bool> mpiInitialized{false};

// Serializes initialize(). NOTE(review): presumably recursive because initialize() calls
// MpiComm::localSession() while holding the lock, which can lead back into initialize() on
// the same thread -- confirm against the COMM_SESSION definition.
std::recursive_mutex mpiMutex;

// Builds the communicator used as the initial local session.
// With ENABLE_MULTI_DEVICE, COMM_SESSION is split with OMPI_COMM_TYPE_HOST, using the caller's
// session rank as ordering key. Otherwise COMM_SESSION itself is wrapped.
// The returned wrapper is created with freeComm == false, so it never frees the communicator.
MpiComm initLocalSession()
{
#if ENABLE_MULTI_DEVICE
    MPI_Comm localComm = nullptr;
    // Check the return code like every other MPI call in this file, instead of carrying on
    // with an unset communicator when the split fails.
    MPICHECK(
        MPI_Comm_split_type(COMM_SESSION, OMPI_COMM_TYPE_HOST, COMM_SESSION.getRank(), MPI_INFO_NULL, &localComm));
    MpiComm localSession{localComm, false};
#else
    MpiComm localSession{COMM_SESSION, false};
#endif // ENABLE_MULTI_DEVICE
    return localSession;
}

} // namespace

// Returns, for every rank of `comm` (in rank order), the corresponding rank in MPI_COMM_WORLD.
// When ENABLE_MULTI_DEVICE is off the result is the single-element vector {0}.
std::vector<int> getWorldRanks(MpiComm const& comm)
{
#if ENABLE_MULTI_DEVICE
    MPI_Group commGroup = nullptr;
    MPI_Group worldGroup = nullptr;
    MPICHECK(MPI_Comm_group(MPI_COMM_WORLD, &worldGroup));
    MPICHECK(MPI_Comm_group(comm, &commGroup));

    int numRanks = 0;
    MPICHECK(MPI_Group_size(commGroup, &numRanks));

    // Ranks 0..numRanks-1 of `comm`, translated into world ranks in one call.
    std::vector<int> commRanks(numRanks);
    std::iota(commRanks.begin(), commRanks.end(), 0);
    std::vector<int> worldRanks(numRanks);
    MPICHECK(MPI_Group_translate_ranks(commGroup, numRanks, commRanks.data(), worldGroup, worldRanks.data()));

    // The group handles were obtained only for the translation; release them again.
    MPICHECK(MPI_Group_free(&commGroup));
    MPICHECK(MPI_Group_free(&worldGroup));
    return worldRanks;
#else
    return std::vector<int>{0};
#endif
}

// Performs the one-time MPI setup for this process.
//
// threadMode            Thread support level requested from MPI_Init_thread.
// forwardAbortToParent  When true, the installed SIGABRT/SIGSEGV handler also sends SIGKILL to
//                       the parent process (non-Windows builds only) before calling MPI_Abort.
//
// If MPI has not been initialized yet (per MPI_Initialized), this calls MPI_Init_thread,
// registers MPI_Finalize with std::atexit, installs the signal handlers and touches
// MpiComm::localSession(). If MPI was already initialized elsewhere, none of that is done.
// In either case the function marks itself as done, so later calls return immediately.
// Raises through the project check macros if an MPI call fails, if the provided thread level
// is lower than the requested one, or if std::signal reports SIG_ERR.
void initialize(MpiThreadSupport threadMode, bool forwardAbortToParent)
{
    // Fast path without taking the lock; the flag is checked again once the lock is held.
    if (mpiInitialized)
    {
        return;
    }
    std::lock_guard<std::recursive_mutex> lk(mpiMutex);
    if (mpiInitialized)
    {
        return;
    }
#if ENABLE_MULTI_DEVICE
    int initialized = 0;
    TLLM_MPI_CHECK(MPI_Initialized(&initialized));
    if (!initialized)
    {
        // Convert once so that the "%d" below receives a real int, not a scoped enum value.
        auto const requiredMode = static_cast<int>(threadMode);
        TLLM_LOG_INFO("Initializing MPI with thread mode %d", requiredMode);
        int providedMode = 0;
        MPICHECK(MPI_Init_thread(nullptr, nullptr, requiredMode, &providedMode));
        TLLM_CHECK_WITH_INFO(providedMode >= requiredMode, "MPI_Init_thread failed");
        std::atexit([]() { MPI_Finalize(); });

        /*
         * We only catch SIGABRT and SIGSEGV because most, if not all, errors in the worker will cause one of these
         * 2 signals. Signals like SIGINT and SIGTERM should be issued to the parent and should terminate MPI
         * workers correctly.
         */
        // Plain function-pointer type accepted by std::signal; avoids the glibc-specific __sighandler_t.
        using SignalHandler = void (*)(int);
        SignalHandler const abortHandler = [](int /*sig*/) { MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE); };
        SignalHandler const forwardingHandler = [](int /*sig*/)
        {
#ifndef _WIN32
            pid_t parentProcessId = getppid();
            kill(parentProcessId, SIGKILL);
#endif
            MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE);
        };
        SignalHandler const handler = forwardAbortToParent ? forwardingHandler : abortHandler;

        for (int sig : {SIGABRT, SIGSEGV})
        {
            auto const previousHandler = std::signal(sig, handler);
            TLLM_CHECK_WITH_INFO(previousHandler != SIG_ERR, "Signal handler setup failed");
        }

        // ensure local MPI communicator is initialized
        MpiComm::localSession();
        TLLM_LOG_INFO("Initialized MPI");
    }
#endif // ENABLE_MULTI_DEVICE
    mpiInitialized = true;
}

// Calls MPI_Barrier on the wrapped communicator.
// Raises through TLLM_THROW when the build has no multi-device support.
void MpiComm::barrier() const
{
#if !ENABLE_MULTI_DEVICE
    TLLM_THROW("Multi device support is disabled.");
#else
    MPICHECK(MPI_Barrier(mComm));
#endif // !ENABLE_MULTI_DEVICE
}

#if ENABLE_MULTI_DEVICE
// Invokes an MPI routine of the shape func(buffer, int count, dtype, args...) for a buffer whose
// element count may exceed what fits into an int.
//
// func    MPI routine to call; `args` are forwarded unchanged to every invocation.
// buffer  Start of the data (void* or void const*).
// size    Number of elements of `dtype` in the buffer.
// dtype   MPI datatype of one element.
// Returns the number of times `func` was invoked (1 when `size` fits into an int).
//
// NOTE(review): `args` are identical for every chunk. For the non-blocking callers in this file,
// which pass a single MPI_Request pointer, each chunk therefore writes to that same request --
// confirm that callers never rely on completion of earlier chunks in the multi-chunk case.
template <typename TMpiFunc, typename TBase, typename... TArgs,
    typename = std::enable_if_t<std::is_same_v<void, std::remove_const_t<TBase>>>>
size_t invokeChunked(TMpiFunc func, TBase* buffer, size_t size, MPI_Datatype dtype, TArgs... args)
{
    constexpr auto maxP1 = static_cast<size_t>(std::numeric_limits<int>::max()) + 1;
    if (TLLM_LIKELY(size < maxP1))
    {
        // size < INT_MAX + 1 here, so the conversion to MPI's int count is lossless.
        MPICHECK(func(buffer, static_cast<int>(size), dtype, args...));
        return 1;
    }

    constexpr size_t alignment = 256;
    int typeSize = 1;
    MPICHECK(MPI_Type_size(dtype, &typeSize));
    // A non-positive size would make the divisions and pointer arithmetic below meaningless.
    TLLM_CHECK_WITH_INFO(typeSize > 0, "MPI_Type_size returned a non-positive size");
    auto const elementSize = static_cast<size_t>(typeSize);

    // We cap at max alignment-bytes chunks that can be sent at once. Only the step computation
    // uses the clamped size; the pointer advance below must use the real element size.
    auto const clampedElementSize = std::min(elementSize, alignment);
    auto const step = maxP1 - (alignment / clampedElementSize);

    using TCast = std::conditional_t<std::is_const_v<TBase>, uint8_t const, uint8_t>;
    size_t count = 0;
    while (size != 0)
    {
        // step <= INT_MAX because alignment / clampedElementSize >= 1.
        auto const currentStep = static_cast<int>(std::min(size, step));
        MPICHECK(func(buffer, currentStep, dtype, args...));
        size -= currentStep;
        // Advance by the number of bytes just handed to MPI.
        size_t const diff = static_cast<size_t>(currentStep) * elementSize;
        buffer = static_cast<TCast*>(buffer) + diff;
        ++count;
    }

    return count;
}
#endif // ENABLE_MULTI_DEVICE

std::shared_ptr<MpiRequest> MpiComm::bcastAsync(void* buffer, size_t size, MpiType dtype, int root) const
{
    std::shared_ptr<MpiRequest> r = std::make_shared<MpiRequest>();
#if ENABLE_MULTI_DEVICE
    invokeChunked(MPI_Ibcast, buffer, size, getMpiDtype(dtype), root, mComm, &r->mRequest);
#else
    TLLM_THROW("Multi device support is disabled.");
#endif // ENABLE_MULTI_DEVICE
    return r;
}

// Non-blocking broadcast of an entire IBuffer, sent as raw bytes from rank `root`.
// Buffers whose memory type is kGPU are rejected by the check below.
std::shared_ptr<MpiRequest> MpiComm::bcastAsync(runtime::IBuffer& buf, int root) const
{
    TLLM_CHECK(buf.getMemoryType() != runtime::MemoryType::kGPU);
    auto* const data = buf.data();
    auto const numBytes = buf.getSizeInBytes();
    return bcastAsync(data, numBytes, MpiType::kBYTE, root);
}

// Blocking broadcast (MPI_Bcast) of `size` elements of `dtype` from rank `root`.
// Raises through TLLM_THROW when the build has no multi-device support.
void MpiComm::bcast(void* buffer, size_t size, MpiType dtype, int root) const
{
#if ENABLE_MULTI_DEVICE
    auto const mpiType = getMpiDtype(dtype);
    invokeChunked(MPI_Bcast, buffer, size, mpiType, root, mComm);
#else
    TLLM_THROW("Multi device support is disabled.");
#endif // ENABLE_MULTI_DEVICE
}

// Blocking broadcast of an entire IBuffer, sent as raw bytes from rank `root`.
void MpiComm::bcast(runtime::IBuffer& buf, int root) const
{
    auto* const data = buf.data();
    auto const numBytes = buf.getSizeInBytes();
    bcast(data, numBytes, MpiType::kBYTE, root);
}

// Starts a non-blocking send (MPI_Isend) of `size` elements of `dtype` to rank `dest` with `tag`.
// Returns the MpiRequest whose mRequest handle was passed to MPI.
// Raises through TLLM_THROW when the build has no multi-device support.
std::shared_ptr<MpiRequest> MpiComm::sendAsync(void const* buffer, size_t size, MpiType dtype, int dest, int tag) const
{
    // size is a size_t: "%zu" is the matching conversion ("%d" would misreport large sizes).
    TLLM_LOG_DEBUG("start MPI_Isend with size %zu", size);
    std::shared_ptr<MpiRequest> r = std::make_shared<MpiRequest>();
#if ENABLE_MULTI_DEVICE
    invokeChunked(MPI_Isend, buffer, size, getMpiDtype(dtype), dest, tag, mComm, &r->mRequest);
#else
    TLLM_THROW("Multi device support is disabled.");
#endif
    TLLM_LOG_DEBUG("end MPI_Isend with size %zu", size);
    return r;
}

// Non-blocking send of an entire IBuffer, transferred as raw bytes to rank `dest` with `tag`.
std::shared_ptr<MpiRequest> MpiComm::sendAsync(runtime::IBuffer const& buf, int dest, int tag) const
{
    auto* const data = buf.data();
    auto const numBytes = buf.getSizeInBytes();
    return sendAsync(data, numBytes, MpiType::kBYTE, dest, tag);
}

// Blocking send (MPI_Send) of `size` elements of `dtype` to rank `dest` with `tag`.
// Raises through TLLM_THROW when the build has no multi-device support.
void MpiComm::send(void const* buffer, size_t size, MpiType dtype, int dest, int tag) const
{
    // size is a size_t: "%zu" is the matching conversion ("%d" would misreport large sizes).
    TLLM_LOG_DEBUG("start MPI_Send with size %zu", size);
#if ENABLE_MULTI_DEVICE
    invokeChunked(MPI_Send, buffer, size, getMpiDtype(dtype), dest, tag, mComm);
#else
    TLLM_THROW("Multi device support is disabled.");
#endif // ENABLE_MULTI_DEVICE
    TLLM_LOG_DEBUG("end MPI_Send with size %zu", size);
}

// Blocking send of an entire IBuffer, transferred as raw bytes to rank `dest` with `tag`.
void MpiComm::send(runtime::IBuffer const& buf, int dest, int tag) const
{
    auto* const data = buf.data();
    auto const numBytes = buf.getSizeInBytes();
    send(data, numBytes, MpiType::kBYTE, dest, tag);
}

// Blocking receive (MPI_Recv) of `size` elements of `dtype` from rank `source` with `tag`.
// Returns the MPI_Status filled in by MPI_Recv.
// Raises through TLLM_THROW when the build has no multi-device support.
MPI_Status MpiComm::recv(void* buffer, size_t size, MpiType dtype, int source, int tag) const
{
    // size is a size_t: "%zu" is the matching conversion ("%d" would misreport large sizes).
    TLLM_LOG_DEBUG("start MPI_Recv with size %zu", size);
    MPI_Status status{};
#if ENABLE_MULTI_DEVICE
    invokeChunked(MPI_Recv, buffer, size, getMpiDtype(dtype), source, tag, mComm, &status);
#else
    TLLM_THROW("Multi device support is disabled.");
#endif // ENABLE_MULTI_DEVICE
    TLLM_LOG_DEBUG("end MPI_Recv with size %zu", size);
    return status;
}

// Blocking receive that fills an entire IBuffer with raw bytes from rank `source` with `tag`.
MPI_Status MpiComm::recv(runtime::IBuffer& buf, int source, int tag) const
{
    auto* const data = buf.data();
    auto const numBytes = buf.getSizeInBytes();
    return recv(data, numBytes, MpiType::kBYTE, source, tag);
}

// Splits this communicator with MPI_Comm_split using `color` and `key`.
// The returned wrapper is created with freeComm == true, i.e. it frees the new communicator
// on destruction. Raises through TLLM_THROW when the build has no multi-device support.
MpiComm MpiComm::split(int color, int key) const
{
    MPI_Comm childComm = nullptr;
#if ENABLE_MULTI_DEVICE
    MPICHECK(MPI_Comm_split(mComm, color, key, &childComm));
#else
    TLLM_THROW("Multi device support is disabled.");
#endif // ENABLE_MULTI_DEVICE
    return MpiComm{childComm, true};
}

// Calls MPI_Allreduce over `count` elements of `dtype`, combining them with `op`.
// Raises through TLLM_THROW when the build has no multi-device support.
void MpiComm::allreduce(void const* sendbuf, void* recvbuf, int count, MpiType dtype, MpiOp op) const
{
#if ENABLE_MULTI_DEVICE
    auto const mpiType = getMpiDtype(dtype);
    auto const mpiOp = getMpiOp(op);
    MPICHECK(MPI_Allreduce(sendbuf, recvbuf, count, mpiType, mpiOp, mComm));
#else
    TLLM_THROW("Multi device support is disabled.");
#endif // ENABLE_MULTI_DEVICE
}

// Calls MPI_Allgather with the same `count` and `dtype` on the send and the receive side.
// Raises through TLLM_THROW when the build has no multi-device support.
void MpiComm::allgather(void const* sendbuf, void* recvbuf, int count, MpiType dtype) const
{
#if ENABLE_MULTI_DEVICE
    auto const mpiType = getMpiDtype(dtype);
    MPICHECK(MPI_Allgather(sendbuf, count, mpiType, recvbuf, count, mpiType, mComm));
#else
    TLLM_THROW("Multi device support is disabled.");
#endif // ENABLE_MULTI_DEVICE
}

// Calls MPI_Allgatherv; `recvcounts` and `displs` are passed to MPI as contiguous int arrays.
// NOTE(review): the vector lengths are not checked here -- presumably callers size them to the
// communicator size; confirm.
// Raises through TLLM_THROW when the build has no multi-device support.
void MpiComm::allgatherv(void const* sendbuf, int sendcount, MpiType sendtype, void* recvbuf,
    std::vector<int> const& recvcounts, std::vector<int> const& displs, MpiType recvtype) const
{
#if ENABLE_MULTI_DEVICE
    auto const mpiSendType = getMpiDtype(sendtype);
    auto const mpiRecvType = getMpiDtype(recvtype);
    MPICHECK(MPI_Allgatherv(
        sendbuf, sendcount, mpiSendType, recvbuf, recvcounts.data(), displs.data(), mpiRecvType, mComm));
#else
    TLLM_THROW("Multi device support is disabled.");
#endif // ENABLE_MULTI_DEVICE
}

// Calls MPI_Mprobe on the wrapped communicator for `source` / `tag`, writing the results to
// `msg` and `status`. Raises through TLLM_THROW when the build has no multi-device support.
void MpiComm::mprobe(int source, int tag, MPI_Message* msg, MPI_Status* status) const
{
#if !ENABLE_MULTI_DEVICE
    TLLM_THROW("Multi device support is disabled.");
#else
    MPICHECK(MPI_Mprobe(source, tag, mComm, msg, status));
#endif // !ENABLE_MULTI_DEVICE
}

// Calls MPI_Improbe for `source` / `tag`.
// Returns true when MPI set its flag to a non-zero value, false otherwise.
// Raises through TLLM_THROW when the build has no multi-device support.
bool MpiComm::improbe(int source, int tag, MPI_Message* msg, MPI_Status* status) const
{
#if ENABLE_MULTI_DEVICE
    int messageFound = 0;
    MPICHECK(MPI_Improbe(source, tag, mComm, &messageFound, msg, status));
    return messageFound != 0;
#else
    TLLM_THROW("Multi device support is disabled.");
    return false;
#endif
}

// Calls MPI_Iprobe for `source` / `tag`.
// Returns true when MPI set its flag to a non-zero value, false otherwise.
// Raises through TLLM_THROW when the build has no multi-device support.
bool MpiComm::iprobe(int source, int tag, MPI_Status* status) const
{
#if ENABLE_MULTI_DEVICE
    int messageFound = 0;
    MPICHECK(MPI_Iprobe(source, tag, mComm, &messageFound, status));
    return messageFound != 0;
#else
    TLLM_THROW("Multi device support is disabled.");
    return false;
#endif
}

// Blocks until iprobe() reports a message for `source` / `tag`, sleeping `periodMs`
// milliseconds between probes. The probed status is discarded.
void MpiComm::recvPoll(int source, int tag, int periodMs) const
{
    auto const pollInterval = std::chrono::milliseconds(periodMs);
    MPI_Status status;
    for (;;)
    {
        if (iprobe(source, tag, &status))
        {
            return;
        }
        std::this_thread::sleep_for(pollInterval);
    }
}

// Returns the caller's rank in the wrapped communicator (MPI_Comm_rank).
// Builds without multi-device support always report rank 0.
int MpiComm::getRank() const
{
    int myRank = 0;
#if ENABLE_MULTI_DEVICE
    MPICHECK(MPI_Comm_rank(mComm, &myRank));
#endif
    return myRank;
}

// Returns the number of ranks in the wrapped communicator (MPI_Comm_size).
// Builds without multi-device support always report a size of 1.
int MpiComm::getSize() const
{
    int numRanks = 1;
#if ENABLE_MULTI_DEVICE
    MPICHECK(MPI_Comm_size(mComm, &numRanks));
#endif
    return numRanks;
}

// Returns the process-wide wrapper around MPI_COMM_WORLD.
// The wrapper is a function-local static created with freeComm == false, so it never frees the
// communicator. initialize() is invoked on every call and returns early once setup is done.
MpiComm const& MpiComm::world()
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    // Constructed before initialize() runs; only the MPI_COMM_WORLD handle is stored here.
    static MpiComm commWorld{MPI_COMM_WORLD, false};
    initialize();
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
    return commWorld;
}

// Returns a mutable reference to the process-wide session communicator.
// The session starts out as a non-owning wrapper (freeComm == false) around MPI_COMM_WORLD.
// initialize() is invoked on every call and returns early once setup is done.
MpiComm& MpiComm::mutableSession()
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    // Constructed before initialize() runs; only the MPI_COMM_WORLD handle is stored here.
    static MpiComm commSession{MPI_COMM_WORLD, false};
    initialize();
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
    return commSession;
}

// Returns a mutable reference to the process-wide local session communicator.
// The communicator is built once, on first call, by initLocalSession().
// NOTE(review): unlike world() and mutableSession(), this function does not call initialize()
// itself; MPI setup presumably happens through COMM_SESSION inside initLocalSession() -- confirm.
MpiComm& MpiComm::mutableLocalSession()
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    static MpiComm localSession = initLocalSession();
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
    return localSession;
}

// Rebuilds the local session communicator so that it contains exactly the ranks that are
// members of both the current session and the current local session, ordered as in the session.
// The previous local communicator is freed and replaced. Calls are serialized by a
// function-local mutex. Does nothing when the build has no multi-device support.
// Raises through the project check macros if an MPI call fails or if the intersection is empty.
void MpiComm::refreshLocalSession()
{
#if ENABLE_MULTI_DEVICE
    static std::mutex mutex;
    std::unique_lock lock(mutex);
    auto initSessionRanks = getWorldRanks(MpiComm::session());
    auto localSessionRanks = getWorldRanks(MpiComm::localSession());

    // Add to intersectionRanks in order of initSessionRanks
    std::vector<int> intersectionRanks;
    std::unordered_set<int> localSessionRanksSet(localSessionRanks.begin(), localSessionRanks.end());
    for (auto rank : initSessionRanks)
    {
        if (localSessionRanksSet.count(rank) != 0)
        {
            intersectionRanks.push_back(rank);
        }
    }
    // front() is used below as the communicator-creation tag, so the list must not be empty.
    TLLM_CHECK_WITH_INFO(
        !intersectionRanks.empty(), "No common ranks between the session and the local session");

    MPI_Group worldGroup = nullptr;
    MPICHECK(MPI_Comm_group(MPI_COMM_WORLD, &worldGroup));
    MPI_Group localGroup = nullptr;
    MPICHECK(MPI_Group_incl(
        worldGroup, static_cast<int>(intersectionRanks.size()), intersectionRanks.data(), &localGroup));
    MPI_Comm localComm = nullptr;
    MPICHECK(MPI_Comm_create_group(MPI_COMM_WORLD, localGroup, intersectionRanks.front(), &localComm));
    // The groups were needed only to create the communicator; release them so that repeated
    // refreshes do not accumulate group handles.
    MPICHECK(MPI_Group_free(&localGroup));
    MPICHECK(MPI_Group_free(&worldGroup));

    // Mark the old local session as owning so that the move assignment below frees its communicator.
    MpiComm::mutableLocalSession().mFreeComm = true;
    MpiComm::mutableLocalSession() = MpiComm{localComm, false};
    TLLM_LOG_INFO("Refreshed the MPI local session");
#endif // ENABLE_MULTI_DEVICE
}

// Wraps the communicator handle `g`. When `freeComm` is true the destructor frees the
// communicator. MPI_COMM_NULL is rejected by the check in the body.
MpiComm::MpiComm(MPI_Comm g, bool freeComm)
    : mComm(g)
    , mFreeComm(freeComm)
{
    TLLM_CHECK(mComm != MPI_COMM_NULL);
}

// Frees the wrapped communicator when this object owns it (mFreeComm) and the handle is set.
// A failing MPI_Comm_free is only logged, since the destructor is noexcept.
MpiComm::~MpiComm() noexcept
{
#if ENABLE_MULTI_DEVICE
    if (!mFreeComm || !mComm)
    {
        return;
    }
    if (MPI_Comm_free(&mComm) != MPI_SUCCESS)
    {
        TLLM_LOG_ERROR("MPI_Comm_free failed");
    }
#endif // ENABLE_MULTI_DEVICE
}

// Move constructor: takes over the handle and the ownership flag of `comm`.
// The source keeps its handle but no longer frees it on destruction.
MpiComm::MpiComm(MpiComm&& comm) noexcept
    : mComm(comm.mComm)
    , mFreeComm(comm.mFreeComm)
{
    comm.mFreeComm = false;
}

// Move assignment: releases the communicator currently owned by this object (if any), then
// takes over the handle and the ownership flag of `comm`. The source keeps its handle but no
// longer frees it on destruction. Self-assignment leaves the object untouched.
MpiComm& MpiComm::operator=(MpiComm&& comm) noexcept
{
    // Without this guard a self-move would free the communicator and keep the dead handle.
    if (this == &comm)
    {
        return *this;
    }
#if ENABLE_MULTI_DEVICE
    // Same release logic as the destructor, done inline so that the destructor is not invoked
    // on an object that stays alive.
    if (mFreeComm && mComm)
    {
        if (MPI_Comm_free(&mComm) != MPI_SUCCESS)
        {
            TLLM_LOG_ERROR("MPI_Comm_free failed");
        }
    }
#endif // ENABLE_MULTI_DEVICE
    mComm = comm.mComm;
    mFreeComm = comm.mFreeComm;
    comm.mFreeComm = false;
    return *this;
}

// Stores the name and the callbacks, then starts the side thread running sideThread().
//
// name       Label used in the trace log messages of this object.
// funcWait   Callback run by the side thread each time it is started via notifyStart().
// funcSetup  Optional callback run once at the beginning of the side thread.
MpiWaitThread::MpiWaitThread(std::string name, std::function<void()> funcWait, std::function<void()> funcSetup)
    // The parameters are taken by value, so move them into the members instead of copying.
    : mName{std::move(name)}
    , mFuncWait{std::move(funcWait)}
    , mFuncSetup{std::move(funcSetup)}
{
    TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__);
    // Started in the body so that the members above are already in place when the thread runs.
    mThread = std::make_unique<std::thread>(&MpiWaitThread::sideThread, this);
    TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__);
}

// Shuts the side thread down: waits until it is not running, raises the exit flag, wakes the
// thread up one more time so that it can observe the flag, and joins it.
MpiWaitThread::~MpiWaitThread()
{
    TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__);
    waitStop();
    mShouldExit.store(true);
    // The thread is parked in waitStart(); it has to be started once to reach the exit check.
    notifyStart();
    mThread->join();
    mThread.reset();
    TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__);
}

// Body of the side thread: runs the optional setup callback once, then repeatedly signals
// "stopped", waits for a start notification, and runs the wait callback.
// NOTE(review): the exit flag is only checked at the top of the loop, so the wake-up issued by
// the destructor appears to trigger one more mFuncWait() call -- confirm this is intended.
void MpiWaitThread::sideThread()
{
    if (mFuncSetup)
    {
        mFuncSetup();
    }
    for (;;)
    {
        if (mShouldExit.load())
        {
            break;
        }
        notifyStop();
        waitStart();
        mFuncWait();
    }
}

// Blocks the caller until mRunning is true.
void MpiWaitThread::waitStart()
{
    TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__);
    std::unique_lock<std::mutex> guard(mMutex);
    // Re-check after every wake-up so that spurious wake-ups are ignored.
    while (!mRunning)
    {
        mCondVar.wait(guard);
    }
    TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__);
}

// Blocks the caller until mRunning is false.
void MpiWaitThread::waitStop()
{
    TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__);
    std::unique_lock<std::mutex> guard(mMutex);
    // Re-check after every wake-up so that spurious wake-ups are ignored.
    while (mRunning)
    {
        mCondVar.wait(guard);
    }
    TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__);
}

// Sets mRunning to true under the mutex and wakes one thread waiting on the condition variable.
void MpiWaitThread::notifyStart()
{
    TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__);
    std::scoped_lock guard{mMutex};
    mRunning = true;
    mCondVar.notify_one();
    TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__);
}

// Sets mRunning to false under the mutex and wakes one thread waiting on the condition variable.
void MpiWaitThread::notifyStop()
{
    TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__);
    std::scoped_lock guard{mMutex};
    mRunning = false;
    mCondVar.notify_one();
    TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__);
}

} // namespace tensorrt_llm::mpi