lower_thread_allreduce.cc 34.9 KB
Newer Older
root's avatar
init  
root committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Lower allreduce to device-implementable IR.
 * \file lower_thread_allreduce.cc
 */
#include <tvm/arith/analyzer.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/target/target.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include <algorithm>
#include <unordered_set>
#include <utility>

#include "runtime/thread_storage_scope.h"
#include "tir/transforms/ir_utils.h"
#include "tir/transforms/update_pointer_storage_scope.h"

namespace tvm {
namespace tl {
using namespace tir;

using runtime::StorageRank;
using runtime::StorageScope;

/*!
 * \brief collect the mapping from the buffer var to its allocate
 */
class AllocateCollector : public StmtExprVisitor {

private:
  bool IsDynamicSharedMemory(Var buffer_var) {
    StorageScope storage_scope = runtime::StorageScope::Create(
        GetPtrStorageScope(std::move(buffer_var)));
    return storage_scope.rank == runtime::StorageRank::kShared &&
           storage_scope.tag == ".dyn";
  }

  bool IsStaticSharedMemory(Var buffer_var) {
    StorageScope storage_scope = runtime::StorageScope::Create(
        GetPtrStorageScope(std::move(buffer_var)));
    return storage_scope.rank == runtime::StorageRank::kShared &&
           storage_scope.tag.empty();
  }

public:
  void VisitStmt_(const AllocateNode *op) final {
    if (IsDynamicSharedMemory(op->buffer_var)) {
      dyn_shmem_allocs_[op->buffer_var.get()] = op;
    } else if (IsStaticSharedMemory(op->buffer_var)) {
      static_shmem_allocs_[op->buffer_var.get()] = op;
    }
    StmtExprVisitor::VisitStmt_(op);
  }
  // The dynamic mapping from the original buffer var to its allocate
  std::unordered_map<const VarNode *, const AllocateNode *> dyn_shmem_allocs_;
  // The static mapping from the original buffer var to its allocate
  std::unordered_map<const VarNode *, const AllocateNode *>
      static_shmem_allocs_;
};

class ThreadAllreduceBuilder final : public StmtExprMutator {
public:
  explicit ThreadAllreduceBuilder(const TargetNode *target,
                                  bool is_dynamic = false)
      : target_(target),
        warp_size_(
            target->GetAttr<Integer>("thread_warp_size", 1).value().IntValue()),
        max_num_threads_(target->GetAttr<Integer>("max_num_threads", -1)
                             .value()
                             .IntValue()) {
    if (is_dynamic) {
      shared_scope = "shared.dyn";
    }
  }

  /*!
   * \brief Maintain the thread_extent / reduce_scope context stacks.
   *
   * While a thread_extent attribute is being mutated it sits on top of
   * thread_extents_; while a reduce_scope attribute is being mutated its
   * CommReducer sits on top of reduce_combiner_. MakeAllreduce reads both.
   */
  Stmt VisitStmt_(const AttrStmtNode *op) final {
    if (op->attr_key == tir::attr::thread_extent) {
      thread_extents_.push_back(op);
      Stmt result = StmtExprMutator::VisitStmt_(op);
      thread_extents_.pop_back();
      return result;
    }

    if (op->attr_key == tir::attr::reduce_scope) {
      const auto *reducer = op->node.as<CommReducerNode>();
      ICHECK(reducer);
      reduce_combiner_.push_back(reducer);
      Stmt result = StmtExprMutator::VisitStmt_(op);
      reduce_combiner_.pop_back();
      return result;
    }

    return StmtExprMutator::VisitStmt_(op);
  }
  /*!
   * \brief Lower Evaluate(tvm_thread_allreduce(...)) statements.
   *
   * The statement is mutated first. If the mutated Evaluate wraps a
   * tvm_thread_allreduce call, the lowered IR from MakeAllreduce replaces it;
   * otherwise the mutated statement is returned.
   */
  Stmt VisitStmt_(const EvaluateNode *op) final {
    Stmt mutated = StmtExprMutator::VisitStmt_(op);
    const CallNode *call = mutated.as<EvaluateNode>()->value.as<CallNode>();
    const bool is_allreduce =
        call != nullptr && call->op.same_as(builtin::tvm_thread_allreduce());
    return is_allreduce ? MakeAllreduce(call) : mutated;
  }
  /*!
   * \brief Rewrite allocations of buffers replaced by MakeAllreduce.
   *
   * When the allocated var is a key of alloc_remap_, the allocation is made
   * to describe the replacement buffer instead: its data var, dtype, shape,
   * and an always-true condition. If the replacement lives in shared_scope,
   * the body is also wrapped in a volatile_scope attribute.
   */
  Stmt VisitStmt_(const AllocateNode *op) final {
    Allocate alloc = Downcast<Allocate>(StmtExprMutator::VisitStmt_(op));

    auto remap_it = alloc_remap_.find(alloc->buffer_var.get());
    if (remap_it == alloc_remap_.end()) {
      return std::move(alloc);
    }

    Buffer replacement = Downcast<Buffer>(remap_it->second);
    AllocateNode *mut = alloc.CopyOnWrite();
    mut->buffer_var = replacement->data;
    mut->dtype = replacement->dtype;
    mut->extents = replacement->shape;
    mut->condition = const_true(replacement->dtype.lanes());

    if (replacement.scope() == shared_scope) {
      // Use volatile access to shared buffer.
      mut->body = AttrStmt(replacement->data, tir::attr::volatile_scope, 1,
                           mut->body);
    }
    return std::move(alloc);
  }

  /*!
   * \brief Find the buffer that should replace `buf`, if any.
   *
   * An existing entry in buf_remap_ is returned directly. Otherwise, when
   * buf's data var has an entry in var_remap_, a copy of `buf` pointing at
   * the new data var is built, cached in buf_remap_, and returned.
   *
   * \param buf The buffer referenced by the IR being rewritten.
   * \return The replacement buffer, or std::nullopt when `buf` is unaffected.
   */
  Optional<Buffer> GetRemappedBuffer(const Buffer &buf) {
    auto cached = buf_remap_.find(buf.get());
    if (cached != buf_remap_.end()) {
      return cached->second;
    }

    auto var_it = var_remap_.find(buf->data.get());
    if (var_it == var_remap_.end()) {
      return std::nullopt;
    }

    Buffer rewritten = buf;
    rewritten.CopyOnWrite()->data = var_it->second;
    buf_remap_[buf.get()] = rewritten;
    return rewritten;
  }

  /*! \brief Point DeclBuffer nodes at their replacement buffer, if any. */
  Stmt VisitStmt_(const DeclBufferNode *op) final {
    DeclBuffer decl = Downcast<DeclBuffer>(StmtExprMutator::VisitStmt_(op));
    Optional<Buffer> replacement = GetRemappedBuffer(decl->buffer);
    if (replacement.defined()) {
      decl.CopyOnWrite()->buffer = replacement.value();
    }
    return std::move(decl);
  }

  /*!
   * \brief Substitute reduction results and retarget remapped buffers.
   *
   * A load from a buffer whose data var is a key of load_remap_ is replaced
   * wholesale by the recorded reduction-result expression. Every index of
   * such a load must be zero. Other loads are mutated normally and then
   * pointed at their replacement buffer via GetRemappedBuffer.
   */
  PrimExpr VisitExpr_(const BufferLoadNode *op) final {
    auto remap_it = load_remap_.find(op->buffer->data.get());
    if (remap_it != load_remap_.end()) {
      for (const auto &idx : op->indices) {
        ICHECK(is_zero(idx))
            << "The index of buffer " << op->buffer << " is " << idx;
      }
      return remap_it->second;
    }

    auto load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
    if (Optional<Buffer> replacement = GetRemappedBuffer(load->buffer);
        replacement.defined()) {
      load.CopyOnWrite()->buffer = replacement.value();
    }
    return std::move(load);
  }

  /*! \brief Point stores at their replacement buffer, if any. */
  Stmt VisitStmt_(const BufferStoreNode *op) final {
    auto store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
    Optional<Buffer> replacement = GetRemappedBuffer(store->buffer);
    if (!replacement.defined()) {
      return std::move(store);
    }
    store.CopyOnWrite()->buffer = replacement.value();
    return std::move(store);
  }

private:
  /*!
   * \brief A thread_extent binding visible at an allreduce call site.
   *
   * Entries compare by the dim_index of their thread scope, so sorting a
   * vector of entries orders them by thread dimension.
   */
  struct ThreadEntry {
    /*! \brief Parsed thread scope (rank and dim_index) of the binding. */
    runtime::ThreadScope scope;
    /*! \brief The bound iteration variable. */
    IterVar iv;
    /*! \brief Constant extent of the binding; 0 until filled in. */
    int extent{0};

    /*! \brief Order by thread dimension index. */
    bool operator<(const ThreadEntry &other) const {
      const int lhs_dim = scope.dim_index;
      const int rhs_dim = other.scope.dim_index;
      return lhs_dim < rhs_dim;
    }
  };

  /*!
   * \brief Lower one tvm_thread_allreduce call into explicit IR.
   *
   * Argument layout of `call` (with size = number of reducer outputs):
   *   args[0]                      : size, as an IntImm;
   *   args[1 .. size]              : the values contributed by this thread;
   *   args[size + 1]               : predicate; threads for which it is false
   *                                  contribute the reducer identity;
   *   args[size + 2 .. 2*size + 1] : loads of the destination buffers
   *                                  (possibly wrapped in a Cast);
   *   args[2*size + 2 ..]          : thread vars reduced over, or the
   *                                  constant 0 for unit extents.
   *
   * When IsWarpReduction(...) accepts the configuration, the reduction is
   * emitted with warp shuffles (one stage when reduce_extent <= warp_size_,
   * otherwise two stages joined through a shared staging buffer). Otherwise
   * a shared-memory tree reduction (MakeBufAllreduce) is emitted. In both
   * cases subsequent loads of the destination buffers are redirected to the
   * result via load_remap_, and their allocations via alloc_remap_,
   * var_remap_ and buf_remap_.
   *
   * \param call The tvm_thread_allreduce call node.
   * \return The lowered statement sequence, wrapped in DeclBuffer/Allocate
   *         nodes for the scratch buffers of the warp path.
   */
  Stmt MakeAllreduce(const CallNode *call) {
    ICHECK(!reduce_combiner_.empty());
    const CommReducerNode *combiner = reduce_combiner_.back();
    size_t size = combiner->result.size();

    const IntImmNode *size_of_args = call->args[0].as<IntImmNode>();
    ICHECK(size_of_args) << call->args[0]->GetTypeKey();
    ICHECK_EQ(size, size_of_args->value);
    Array<PrimExpr> inits = combiner->identity_element;
    std::vector<PrimExpr> values(size);
    std::vector<DataType> types(size);
    PrimExpr cond = call->args[size + 1];
    for (size_t idx = 0; idx < size; ++idx) {
      values[idx] = call->args[1 + idx];
      // Threads whose predicate is false contribute the identity element.
      if (!is_one(cond)) {
        values[idx] = Select(cond, values[idx], inits[idx]);
      }
      types[idx] = values[idx].dtype();
    }
    std::vector<Buffer> buffers(size);
    for (size_t idx = 0; idx < size; ++idx) {
      PrimExpr arg = call->args[2 + size + idx];
      // Loads from boolean buffers may have cast nodes inserted by
      // earlier passes.
      if (auto cast = arg.as<CastNode>()) {
        arg = cast->value;
      }
      buffers[idx] = Downcast<BufferLoad>(arg)->buffer;
    }

    std::unordered_set<const VarNode *> reduce_set;
    for (size_t i = 2 + 2 * size; i < call->args.size(); ++i) {
      const VarNode *v = call->args[i].as<VarNode>();
      // Simplification may replace an iteration variable with a constant when
      // the extent of the iteration is 1. As threaded IterVars always start
      // from 0, such a constant argument can simply be ignored.
      if (v) {
        reduce_set.insert(v);
      } else {
        ICHECK(call->args[i].as<IntImmNode>() &&
               call->args[i].as<IntImmNode>()->value == 0)
            << "arg" << i << "should be a VarNode or IntImmNode "
            << "while it is " << call->args[i];
      }
    }

    // Whether `entries` already contains a binding on thread dim `dim_index`.
    auto has_dim = [](const std::vector<ThreadEntry> &entries, int dim_index) {
      return std::any_of(entries.begin(), entries.end(),
                         [dim_index](const ThreadEntry &entry) {
                           return entry.scope.dim_index == dim_index;
                         });
    };

    size_t nmatch = 0;
    std::vector<ThreadEntry> vred, vpar;
    for (const AttrStmtNode *attr : thread_extents_) {
      ThreadEntry e;
      IterVar iv = Downcast<IterVar>(attr->node);
      e.scope = runtime::ThreadScope::Create(iv->thread_tag);
      e.iv = iv;
      ICHECK_LE(e.scope.rank, 1);
      ICHECK_GE(e.scope.dim_index, 0)
          << "vthread do not work with cross thread reduction";
      if (e.scope.rank == 1) {
        const auto *ptr = attr->value.as<IntImmNode>();
        ICHECK(ptr) << "Need constant extent for reduce set " << iv;
        e.extent = static_cast<int>(ptr->value);
        // Thread dimensions of extent 1 contribute nothing; skip them.
        if (e.extent == 1) {
          continue;
        }
        // Record each thread dimension at most once per list, even when
        // several nested thread_extent attributes bind it.
        if (reduce_set.count(iv->var.get())) {
          if (!has_dim(vred, e.scope.dim_index)) {
            vred.push_back(e);
            ++nmatch;
          }
        } else if (!has_dim(vpar, e.scope.dim_index)) {
          vpar.push_back(e);
        }
      }
    }

    // A thread dimension that is reduced over cannot also be a parallel
    // (group) dimension: drop every such entry from vpar.
    vpar.erase(std::remove_if(vpar.begin(), vpar.end(),
                              [&](const ThreadEntry &entry) {
                                return has_dim(vred, entry.scope.dim_index);
                              }),
               vpar.end());

    ICHECK_EQ(nmatch, reduce_set.size())
        << "Not all reduce index are presented in the context";
    std::sort(vred.begin(), vred.end());
    std::sort(vpar.begin(), vpar.end());
    // the size of each index.
    int reduce_extent, group_extent;
    PrimExpr reduce_index = FlattenThread(vred, &reduce_extent);
    PrimExpr group_index = FlattenThread(vpar, &group_extent);

    // Product of the reduce-thread extents that precede the first parallel
    // thread dimension (in dim_index order), i.e. the longest reduce extent
    // that stays contiguous after flattening.
    int contiguous_reduce_extent = 1;
    std::vector<std::tuple<int, int, bool>>
        block_threads; // tuple(dim_index, extent, is_reduce)
    for (const ThreadEntry &thr : vred) {
      if (thr.scope.rank == 1) { // threadIdx
        block_threads.emplace_back(thr.scope.dim_index, thr.extent, true);
      }
    }
    for (const ThreadEntry &thr : vpar) {
      if (thr.scope.rank == 1) { // threadIdx
        block_threads.emplace_back(thr.scope.dim_index, thr.extent, false);
      }
    }
    // sort according to dim_index
    std::sort(block_threads.begin(), block_threads.end());
    for (auto &&thr_attr : block_threads) {
      auto [dim_index, extent, is_reduce] = thr_attr;
      (void)dim_index; // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=81767
      if (is_reduce) {
        contiguous_reduce_extent *= extent;
      } else {
        break;
      }
    }

    std::vector<Stmt> seq;
    std::vector<Buffer> new_alloc_bufs;
    //
    // This is an optimization. For small reduction sizes, it may be beneficial
    // for a single warp to perform the entire reduction. No trips to shared
    // memory and no cross warp synchronizations are required.
    // The following code emits the reduction as follows:
    //
    // Allocate reduction vars v[i], i = 0..size-1
    //
    // for offset from WARP_SIZE to 1 by 2
    //
    //   a    <- load(v[i])
    //   b    <- shuffle_down(load(v[i], offset))
    //   v[i] <- reduction(a, b)
    //
    // broadcast results from lane 0 to all other lanes and store
    // the final reduction result to the proper location.
    //
    // When the thread extent is multiple of warp size, we can use a two-stage
    // warp-level reduction to optimize. This is implemented by applying the
    // algorithm above twice.
    //
    // For example, suppose we want to use 512 threads to reduce 512 elements
    // and the warp size is 32. In this case there are (512 / 32) = 16 warps.
    // In the first stage, each of the 16 warps reduces 32 elements. So after
    // the stage, we have 16 remaining elements to be reduced, one for each
    // warp. We store the 16 elements in shared memory, and start the second
    // stage. In the second stage we use the first 16 lanes of the first warp to
    // reduce the remaining elements, and this reduction can also be optimized
    // by shuffle_down warp-level primitives.
    PrimExpr zero_index = make_const(reduce_index->dtype, 0);

    if (IsWarpReduction(types, group_extent, reduce_extent,
                        contiguous_reduce_extent)) {
      std::vector<PrimExpr> reduce_results;
      DataType mask_dtype = DataType::UInt(32);
      PrimExpr mask = Call(mask_dtype, builtin::tvm_warp_activemask(), {});

      if (reduce_extent <= warp_size_) {
        std::tie(reduce_results, new_alloc_bufs) = MakeWarpAllreduce(
            values, types, combiner, reduce_index, reduce_extent, group_index,
            mask, std::nullopt, &seq);

        // MakeWarpAllreduce appends its cached-mask buffer last only when
        // need_warp_shuffle_mask_ is set; otherwise the last buffer is a
        // shuffle temporary and must not be used as a mask. Mirror the
        // mask choice MakeWarpAllreduce itself used for its shuffles.
        Optional<Buffer> mask_buffer;
        if (need_warp_shuffle_mask_) {
          mask_buffer = new_alloc_bufs.back();
        }

        // Broadcast the reduction result from lane 0 to all other lanes.
        // This avoids to emit predicated stores, as all threads are
        // uniformly writing the same result.
        for (size_t i = 0; i < size; ++i) {
          Buffer buf = Downcast<BufferLoad>(reduce_results[i])->buffer;
          PrimExpr val = BufferLoad(buf, {zero_index});
          ICHECK_EQ(val->dtype, types[i]);
          PrimExpr splat = WarpShuffle(builtin::tvm_warp_shuffle(), mask_buffer,
                                       val, reduce_extent * group_index);
          seq.push_back(BufferStore(buf, splat, {zero_index}));
        }
      } else {
        int n_warps = reduce_extent / warp_size_;
        std::vector<Buffer> local_bufs;

        // 1. Create the staging buffer in shared memory.
        std::vector<Buffer> staging_shared_bufs;
        staging_shared_bufs.reserve(size);
        for (size_t i = 0; i < size; ++i) {
          Buffer staging_shared_buf = decl_buffer(
              /*shape=*/{make_const(reduce_index->dtype,
                                    n_warps * group_extent)},
              /*dtype=*/buffers[i]->dtype, /*name=*/"red_buf_staging",
              /*storage_scope=*/shared_scope);
          staging_shared_bufs.push_back(staging_shared_buf);
          new_alloc_bufs.push_back(staging_shared_buf);
        }

        // 2. First round of allreduce.
        std::tie(reduce_results, local_bufs) =
            MakeWarpAllreduce(values, types, combiner, reduce_index, warp_size_,
                              group_index, mask, std::nullopt, &seq);
        new_alloc_bufs.insert(new_alloc_bufs.end(), local_bufs.begin(),
                              local_bufs.end());

        // 3. Write allreduce results to staging buffer.
        std::vector<Stmt> write_staging_buf;
        write_staging_buf.reserve(size);
        for (size_t i = 0; i < size; ++i) {
          new_alloc_bufs.push_back(
              Downcast<BufferLoad>(reduce_results[i])->buffer);
          write_staging_buf.push_back(BufferStore(
              /*buffer=*/staging_shared_bufs[i],
              /*value=*/reduce_results[i],
              /*indices=*/
              {group_index * n_warps + floordiv(reduce_index, warp_size_)}));
        }
        // Only the first lane of each warp writes its warp's partial result.
        PrimExpr is_warp_lane0 =
            floormod(reduce_index, warp_size_) == zero_index;
        seq.push_back(
            IfThenElse(is_warp_lane0, SeqStmt::Flatten(write_staging_buf)));
        seq.push_back(SyncThread(shared_scope));

        // 4. Load staging buffer.
        //    Second round of allreduce.
        for (size_t i = 0; i < size; ++i) {
          values[i] =
              BufferLoad(/*buffer=*/staging_shared_bufs[i],
                         /*indices=*/{group_index * n_warps + reduce_index});
        }
        std::tie(reduce_results, local_bufs) = MakeWarpAllreduce(
            values, types, combiner, reduce_index, n_warps, group_index, mask,
            /*predicate=*/reduce_index <
                make_const(reduce_index->dtype, n_warps),
            &seq);
        new_alloc_bufs.insert(new_alloc_bufs.end(), local_bufs.begin(),
                              local_bufs.end());

        // 5. Create shared memory buffer(s) of `group_extent` elements, storing
        // the allreduce results so each thread can access.
        std::vector<Stmt> write_result;
        write_result.reserve(size);
        for (size_t i = 0; i < size; ++i) {
          new_alloc_bufs.push_back(
              Downcast<BufferLoad>(reduce_results[i])->buffer);
          Buffer broadcast_shared_buf = decl_buffer(
              /*shape=*/{make_const(reduce_index->dtype, group_extent)},
              /*dtype=*/buffers[i]->dtype, /*name=*/"red_result",
              /*storage_scope=*/shared_scope);
          write_result.push_back(BufferStore(broadcast_shared_buf,
                                             reduce_results[i], {group_index}));
          // Update `reduce_results`, pointing to the value loaded from the
          // shared memory buffer.
          reduce_results[i] = BufferLoad(broadcast_shared_buf, {group_index});
        }
        seq.push_back(IfThenElse(reduce_index == zero_index,
                                 SeqStmt::Flatten(write_result)));
        seq.push_back(SyncThread(shared_scope));
      }

      // Write back allreduce results and update existing allocations: the
      // original destination buffer is replaced by the buffer holding the
      // result, and loads from it are replaced by the result expression.
      for (size_t i = 0; i < size; ++i) {
        ICHECK(!load_remap_.count(buffers[i]->data.get()));
        Buffer buf = Downcast<BufferLoad>(reduce_results[i])->buffer;
        ICHECK_EQ(reduce_results[i]->dtype, types[i]);
        load_remap_[buffers[i]->data.get()] = reduce_results[i];
        alloc_remap_[buffers[i]->data.get()] = buf;
        var_remap_[buffers[i]->data.get()] = buf->data;
        buf_remap_[buffers[i].get()] = buf;
      }
    } else {
      std::vector<Buffer> shared_bufs(size);
      if (reduce_extent == 1) {
        // special case, no reduction is needed.
        std::vector<Stmt> stores;
        stores.reserve(size);
        for (size_t i = 0; i < size; ++i) {
          stores.emplace_back(BufferStore(buffers[i], values[i], {0}));
        }
        return SeqStmt::Flatten(stores);
      }
      // This sync is necessary because there might be incomplete read of
      // previous iteration on the same buffer.
      seq.emplace_back(SyncThread(shared_scope));
      for (size_t idx = 0; idx < size; ++idx) {
        shared_bufs[idx] = decl_buffer(
            {IntImm(group_index->dtype, group_extent * reduce_extent)},
            types[idx], "red_buf" + std::to_string(idx), shared_scope);
        seq.emplace_back(
            BufferStore(shared_bufs[idx], values[idx],
                        {BufIndex(reduce_index, group_index, reduce_extent)}));
      }
      seq.emplace_back(SyncThread(shared_scope));
      seq.emplace_back(MakeBufAllreduce(
          combiner, types, shared_bufs, reduce_index, group_index,
          reduce_extent, group_extent, contiguous_reduce_extent));
      // The result of each group sits at reduce_index 0 of its slice.
      for (size_t idx = 0; idx < size; ++idx) {
        ICHECK(!load_remap_.count(buffers[idx]->data.get()));
        BufferLoad load(shared_bufs[idx],
                        {BufIndex(make_zero(reduce_index.dtype()), group_index,
                                  reduce_extent)});
        ICHECK_EQ(load->dtype, types[idx]);
        load_remap_[buffers[idx]->data.get()] = load;
        alloc_remap_[buffers[idx]->data.get()] = shared_bufs[idx];
        var_remap_[buffers[idx]->data.get()] = shared_bufs[idx]->data;
        buf_remap_[buffers[idx].get()] = shared_bufs[idx];
      }
    }

    // Fix all local allocations as all statements are built.
    Stmt body = SeqStmt::Flatten(seq);
    for (const Buffer &buf : new_alloc_bufs) {
      body = DeclBuffer(buf, body);
      body = Allocate(buf->data, buf->dtype, buf->shape,
                      const_true(buf->dtype.lanes()), body);
    }

    return body;
  }

  std::pair<std::vector<PrimExpr>, std::vector<Buffer>>
  MakeWarpAllreduce(std::vector<PrimExpr> src_values,                //
                    std::vector<DataType> dtypes,                    //
                    const CommReducerNode *combiner,                 //
                    const PrimExpr &reduce_index, int reduce_extent, //
                    const PrimExpr &group_index,                     //
                    const PrimExpr &mask,
                    const Optional<PrimExpr> &predicate, //
                    std::vector<Stmt> *seq) {
    int n_buffers = src_values.size();

    std::vector<Buffer> shared_bufs;
    std::vector<Buffer> local_bufs;
    shared_bufs.reserve(n_buffers);

    // This is the index to the reduction variable, one reduction
    // variable per warp. Local scope seems easier to reason without
    // relying on a pattern match pass to fix it later.
    Array<PrimExpr> zero_indices = {0};
    Array<PrimExpr> shape = {1};

    std::vector<Stmt> load_values;
    load_values.reserve(n_buffers);
    for (int idx = 0; idx < n_buffers; ++idx) {
      shared_bufs.push_back(decl_buffer(
          shape, dtypes[idx], "red_buf" + std::to_string(idx), "local"));
      load_values.push_back(
          BufferStore(shared_bufs[idx], src_values[idx], zero_indices));

      // Uses a local variable to store the shuffled data.  Later
      // on, an allocation will be built for this local variable.
      local_bufs.push_back(
          decl_buffer(shape, dtypes[idx], "t" + std::to_string(idx), "local"));
    }

    if (predicate.defined()) {
      seq->push_back(
          IfThenElse(predicate.value(), SeqStmt::Flatten(load_values)));
    } else {
      seq->insert(seq->end(), load_values.begin(), load_values.end());
    }

    // The mask for this reducer, as this reducer may sit inside
    // a divergent control flow. Here it uses a variable to cache the current
    // active channels.
    Optional<Buffer> mask_buffer;
    if (need_warp_shuffle_mask_) {
      mask_buffer = decl_buffer(shape, mask->dtype, "mask", "local");
      seq->emplace_back(BufferStore(mask_buffer.value(), mask, zero_indices));
      // Push the buffer description.  Later this will have an
      // allocation built for it.
      local_bufs.push_back(mask_buffer.value());
    }

    // Emit reductions within a warp.
    int start_offset = 1;
    while (start_offset * 2 < reduce_extent) {
      start_offset *= 2;
    }
    for (int offset = start_offset; offset > 0; offset /= 2) {
      // Load reduction values, no synchronization needed.
      Array<PrimExpr> a, b;
      for (int i = 0; i < n_buffers; ++i) {
        const Buffer &shared_buf = shared_bufs[i];
        BufferLoad val(shared_buf, zero_indices);
        ICHECK_EQ(val->dtype, dtypes[i]);
        a.push_back(val);

        // __shfl_*sync calls shall not appear in if_then_else expressions
        // as this is causing extra divergency. E.g.
        //
        // v1 = (v2 < v3) ? v3 : __shfl_sync(mask, v1, 0);
        //
        // behaves differently from
        //
        // int t = __shfl_sync(mask, v1, 0);
        // v1 = (v2 < v3) ? v3 : t;
        //
        // The former may cause dead lock as there is a divergent
        // branch with a warp sync call inside.
        PrimExpr other = WarpShuffle(builtin::tvm_warp_shuffle_down(),
                                     mask_buffer, val, offset);
        const Buffer &local_buf = local_bufs[i];
        Stmt s = BufferStore(local_buf, other, zero_indices);
        seq->push_back(s);

        BufferLoad load = BufferLoad(local_buf, zero_indices);
        ICHECK_EQ(load->dtype, dtypes[i]);
        b.push_back(load);
      }

      // Do reductions.
      Array<PrimExpr> ret = (*combiner)(a, b);

      // Store the reduction result to itself.
      std::vector<Stmt> stores;
      stores.reserve(n_buffers);
      for (int i = 0; i < n_buffers; ++i) {
        const Buffer &buf = shared_bufs[i];
        stores.push_back(BufferStore(buf, ret[i], zero_indices));
      }

      // During the sub-warp reduction, values from inactive threads could be
      // read, which is an undefined behavior according to the cuda document.
      //
      // In practice, the return value are usually 0, which does no harm to sum
      // reduction. However, the result can be incorrect in max or prod
      // reduction. Therefore an additional range check has to be performed to
      // ensure the correctness.
      if (offset * 2 > reduce_extent) {
        PrimExpr cond = reduce_index + offset < reduce_extent;
        seq->push_back(IfThenElse(cond, SeqStmt::Flatten(stores)));
      } else {
        seq->push_back(SeqStmt::Flatten(stores));
      }
    }

    std::vector<PrimExpr> reduce_results;
    reduce_results.reserve(n_buffers);
    for (int i = 0; i < n_buffers; ++i) {
      reduce_results.push_back(BufferLoad(shared_bufs[i], zero_indices));
    }

    return {reduce_results, local_bufs};
  }

  /*!
   * \brief Emit a tree allreduce over shared-memory buffers.
   *
   * Each thread's value is expected at BufIndex(reduce_index, group_index,
   * reduce_extent) of shared_bufs[i]. The caller is presumed to have stored
   * it there and synchronized first. The emitted code repeatedly combines
   * the element at reduce_index with the one at reduce_index + offset,
   * halving offset each step. When it finishes, the reduced value of each
   * group sits at BufIndex(0, group_index, reduce_extent).
   *
   * \param combiner                 Reducer used to combine pairs of values.
   * \param types                    Data type of each buffer's elements.
   * \param shared_bufs              One buffer per reducer output.
   * \param reduce_index             Flattened index within the reduction.
   * \param group_index              Flattened index of the reduction group.
   * \param reduce_extent            Number of threads reduced over (>= 2).
   * \param group_extent             Number of independent groups.
   * \param contiguous_reduce_extent Reduce extent that stays contiguous after
   *                                 flattening (see MakeAllreduce).
   * \return The statement sequence implementing the reduction.
   */
  Stmt MakeBufAllreduce(const CommReducerNode *combiner,
                        const std::vector<DataType> &types,
                        const Array<Buffer> &shared_bufs, PrimExpr reduce_index,
                        PrimExpr group_index, int reduce_extent,
                        int group_extent, int contiguous_reduce_extent) {
    // Smallest power of two >= reduce_extent.
    int reduce_align = 1;
    while (reduce_extent > reduce_align) {
      reduce_align = reduce_align << 1;
    }
    ICHECK_GT(reduce_align, 1);
    std::vector<Stmt> seq;

    size_t size = shared_bufs.size();
    PrimExpr buf_index = BufIndex(reduce_index, group_index, reduce_extent);
    // make reduction
    // fload(offset): combine this thread's element with the one `offset`
    // slots further on, returning the combined expressions (no store).
    auto fload = [&](int offset) {
      Array<PrimExpr> a, b;
      for (size_t i = 0; i < size; ++i) {
        BufferLoad b_load(
            shared_bufs[i],
            {BufIndex(reduce_index + offset, group_index, reduce_extent)});
        ICHECK_EQ(b_load->dtype, types[i]);
        b.push_back(b_load);

        BufferLoad a_load(shared_bufs[i], {buf_index});
        ICHECK_EQ(a_load->dtype, types[i]);
        a.push_back(a_load);
      }
      Array<PrimExpr> ret = (*combiner)(a, b);
      return ret;
    };
    // fstore(ret): write ret[i] back to this thread's slot in shared_bufs[i].
    auto fstore = [&](const Array<PrimExpr> &ret) {
      std::vector<Stmt> stores(size);
      for (size_t i = 0; i < size; ++i) {
        stores[i] = BufferStore(shared_bufs[i], ret[i], {buf_index});
      }
      return SeqStmt::Flatten(stores);
    };
    // freduce(offset): fload followed by fstore.
    auto freduce = [&](int offset) {
      auto ret = fload(offset);
      return fstore(ret);
    };
    // Step one: if reduce_extent is not a power of two, fold the tail
    // [reduce_align / 2, reduce_extent) onto the leading elements so the
    // remaining steps operate on a power-of-two range.
    if (reduce_align > reduce_extent) {
      // reduction with the boundary condition
      reduce_align = reduce_align >> 1;
      PrimExpr cond = reduce_index < (reduce_extent - reduce_align);
      seq.emplace_back(IfThenElse(cond, freduce(reduce_align)));
      seq.emplace_back(SyncThread(shared_scope));
    }

    // normal synchronization
    // Each step halves the active range and is followed by a shared_scope
    // barrier. Steps continue while the range exceeds the contiguous reduce
    // extent or the warp size. When warp_align is false, they continue all
    // the way down to 1, so the in-warp phase below is skipped. warp_align
    // presumably guarantees that the threads of the in-warp phase lie in one
    // warp — TODO confirm.
    bool warp_align =
        group_extent == 1 || contiguous_reduce_extent % warp_size_ == 0;
    while (reduce_align > contiguous_reduce_extent ||
           reduce_align > warp_size_ || !warp_align) {
      if (reduce_align == 1) {
        break;
      }
      reduce_align = reduce_align >> 1;
      PrimExpr cond = reduce_index < reduce_align;
      seq.emplace_back(IfThenElse(cond, freduce(reduce_align)));
      seq.emplace_back(SyncThread(shared_scope));
    }
    // in warp synchronization.
    // The remaining steps run under a single guard and use only "warp"
    // scoped syncs between loads and stores.
    if (reduce_align > 1) {
      PrimExpr in_warp_cond = reduce_index < (reduce_align >> 1);

      std::vector<Stmt> in_warp_seq;

      while (reduce_align > 1) {
        reduce_align = reduce_align >> 1;

        // freduce can read/write to the same memory location.  For
        // example, with reduce_align of 4, threadIdx 3 reads from
        // memory location 7 as threadIdx 7 is writing to it.
        // Therefore, we need to separate out the load from the store
        // with a memory barrier in-between.  This isn't necessary for
        // the earlier normal synchronization, because those are each
        // protected by an if-statement.  The if-statement is avoided
        // here to reduce thread divergence.
        auto loads = fload(reduce_align);

        // Bind each combined value to a fresh Let variable, so that the
        // loads happen before the first warp sync and the store after it.
        Array<Var> in_warp_local_vars;
        for (auto expr : loads) {
          Var var("w_" + std::to_string(reduce_align) + "_" +
                      std::to_string(in_warp_local_vars.size()),
                  expr->dtype);
          in_warp_local_vars.push_back(var);
        }

        std::vector<Stmt> in_let_statement;
        in_let_statement.emplace_back(SyncThread("warp"));
        in_let_statement.emplace_back(
            fstore({in_warp_local_vars.begin(), in_warp_local_vars.end()}));
        in_let_statement.emplace_back(SyncThread("warp"));

        Stmt body = SeqStmt::Flatten(in_let_statement);
        for (size_t i = 0; i < size; i++) {
          body = LetStmt(in_warp_local_vars[i], loads[i], body);
        }
        in_warp_seq.push_back(body);
      }

      Stmt warp_body = SeqStmt::Flatten(in_warp_seq);

      seq.emplace_back(IfThenElse(in_warp_cond, warp_body));
      seq.emplace_back(SyncThread(shared_scope));
    }
    return SeqStmt::Flatten(seq);
  }
  // Linearize the thread indices in `tvec` into one index expression, with the
  // first entry varying fastest:
  //   idx = v0 + v1 * e0 + v2 * (e0 * e1) + ...
  // The product of all entry extents is written to `*out_total_extent`.
  // An empty `tvec` yields the Int(32) constant 0 and a total extent of 1.
  PrimExpr FlattenThread(const std::vector<ThreadEntry> &tvec,
                         int *out_total_extent) {
    int &stride = *out_total_extent;
    stride = 1;
    if (tvec.empty()) {
      return make_zero(DataType::Int(32));
    }

    // The first entry contributes with stride 1.
    const ThreadEntry &head = tvec.front();
    PrimExpr flat = head.iv->var;
    stride = head.extent;

    // Each later entry is scaled by the product of all extents before it.
    for (size_t k = 1; k < tvec.size(); ++k) {
      const ThreadEntry &entry = tvec[k];
      flat = flat + entry.iv->var * stride;
      stride *= entry.extent;
    }
    return flat;
  }
  // Flat offset of (group_index, reduce_index) in a reduction buffer that
  // holds `reduce_extent` elements per group:
  //   group_index * reduce_extent + reduce_index  (simplified by analyzer_).
  // When `group_index` is the constant zero, `reduce_index` is returned as-is
  // without simplification.
  PrimExpr BufIndex(PrimExpr reduce_index, const PrimExpr &group_index,
                    int reduce_extent) {
    if (is_zero(group_index)) {
      return reduce_index;
    }
    PrimExpr linear = group_index * reduce_extent + reduce_index;
    return analyzer_.Simplify(linear);
  }
  // Build an `Evaluate(tvm_storage_sync(sync))` statement that returns Int(32).
  // `sync` names the scope to synchronize (e.g. "warp", or the builder's
  // shared scope).
  static Stmt SyncThread(const std::string &sync) {
    Array<PrimExpr> call_args{StringImm(sync)};
    Call sync_call(DataType::Int(32), builtin::tvm_storage_sync(), call_args);
    return Evaluate(sync_call);
  }

  // Build a call to the warp-shuffle intrinsic `op` applied to `val`.
  // Arguments are {mask, val, delta_or_lane, width, width}, where both width
  // slots are the device warp size. The mask is element 0 of `mask_buffer`
  // when one is supplied; otherwise it is the Int(32) constant 0. The result
  // has `val`'s dtype.
  PrimExpr WarpShuffle(const Op &op, const Optional<Buffer> &mask_buffer,
                       const PrimExpr &val, PrimExpr delta_or_lane) {
    Array<PrimExpr> lane0_index = {0};
    PrimExpr mask = mask_buffer.defined()
                        ? PrimExpr(BufferLoad(mask_buffer.value(), lane0_index))
                        : PrimExpr(IntImm(DataType::Int(32), 0));
    const PrimExpr warp_width = IntImm(DataType::Int(32), warp_size_);
    Array<PrimExpr> call_args{mask, val, std::move(delta_or_lane), warp_width,
                              warp_width};
    return Call(val.dtype(), op, call_args);
  }

  // Check if we can use warp level reduction.
  //
  // Note: The ROCm backend will only have warp reductions for now.
  // Also, the warp/wavefront size differs (64 on rocm, 32 on cuda and metal).
  bool IsWarpReduction(const std::vector<DataType> &types, int group_extent,
                       int reduce_extent, int contiguous_reduce_extent) {
    if ((target_->kind->name != "cuda") && (target_->kind->name != "rocm") &&
        (target_->kind->name != "metal")) {
      return false;
    }

    need_warp_shuffle_mask_ = target_->kind->name != "metal";

    // rocm only supports 32 bit operands for shuffling at the moment
    if ((target_->kind->name == "rocm") &&
        (std::any_of(types.begin(), types.end(), [](DataType ty) {
          if (ty.is_fixed_length_vector())
            return ty.bits() * ty.lanes() != 32;
          return ty.bits() != 32;
        }))) {
      return false;
    }

    // Supported types:
    // {u}int, {u}long, {u}long long, float, double, half/half2
    if (std::any_of(types.begin(), types.end(), [](DataType ty) {
          if (ty.is_float16())
            return ty.lanes() > 2;
          if (ty.is_fixed_length_vector())
            return true;
          return ty.bytes() < 4 || ty.bytes() > 8;
        })) {
      return false;
    }
    if (thread_extents_.empty()) {
      return false;
    }

    // reduce region must be contiguous.
    if (contiguous_reduce_extent != reduce_extent) {
      return false;
    }

    // whether reduce_extent and group_extent are valid for warp reduction.
    if (target_->kind->name == "rocm") {
      return reduce_extent == warp_size_;
    } else {
      if (reduce_extent == 1) {
        return false; // no need to warp reduce
      } else {
        bool is_subwarp_reduction = warp_size_ % reduce_extent == 0;
        bool is_multiwarp_reduction =
            max_num_threads_ != -1 &&
            max_num_threads_ <= warp_size_ * warp_size_ &&
            reduce_extent % warp_size_ == 0;
        if (is_subwarp_reduction || is_multiwarp_reduction) {
          return true;
        } else {
          return group_extent == 1 && reduce_extent <= warp_size_;
        }
      }
    }
  }

  // The compilation target. IsWarpReduction reads target_->kind->name to
  // decide whether a warp-level reduction is possible.
  const TargetNode *target_ = nullptr;
  // The shared scope: the scope name passed to SyncThread when synchronizing
  // the shared reduction buffers (defaults to "shared").
  String shared_scope = "shared";
  // The warp size of the device (defaults to 1). WarpShuffle uses it as the
  // shuffle width, and IsWarpReduction uses it in its eligibility checks.
  int warp_size_{1};
  // The maximum number of threads of the device. "-1" denotes unknown.
  int max_num_threads_{-1};
  // Whether warp shuffles must carry a mask argument. IsWarpReduction sets
  // this (true for every supported target except Metal). It is
  // value-initialized to false.
  bool need_warp_shuffle_mask_{};

  // surrounding scope of thread extent.
  // IsWarpReduction rejects warp reduction when this is empty.
  std::vector<const AttrStmtNode *> thread_extents_;
  // NOTE(review): presumably the combiners of the enclosing reductions; they
  // are populated/used outside this view.
  std::vector<const CommReducerNode *> reduce_combiner_;
  // The load remap, keyed by buffer variable (populated outside this view).
  std::unordered_map<const VarNode *, PrimExpr> load_remap_;
  // Allocate remap, keyed by buffer variable (populated outside this view).
  std::unordered_map<const VarNode *, Buffer> alloc_remap_;
  // BufferVar remap: original variable -> replacement variable.
  std::unordered_map<const VarNode *, Var> var_remap_;
  // Buffer remap: original buffer -> replacement buffer.
  std::unordered_map<const BufferNode *, Buffer> buf_remap_;
  // Internal analyzer; BufIndex uses it to simplify flattened buffer indices.
  arith::Analyzer analyzer_;
};

namespace transform {
using namespace tir::transform;

// Create the "tl.LowerThreadAllreduce" PrimFunc pass (opt level 0, no
// required passes). For each PrimFunc the pass:
//   1. counts the dynamic shared-memory allocations with AllocateCollector;
//      the builder's dynamic flag is set when more than one is found;
//   2. requires a `kTarget` attribute (ICHECK fails otherwise);
//   3. rewrites the body with ThreadAllreduceBuilder for that target.
tvm::transform::Pass LowerThreadAllreduce() {
  auto lower_func = [](PrimFunc f, const IRModule & /*mod*/,
                       const PassContext & /*ctx*/) {
    AllocateCollector dyn_alloc_collector;
    dyn_alloc_collector(f->body);
    const bool has_dynamic_shmem =
        dyn_alloc_collector.dyn_shmem_allocs_.size() > 1;

    auto *func_node = f.CopyOnWrite();
    const auto target = f->GetAttr<Target>(tvm::attr::kTarget);
    ICHECK(target.defined())
        << "LowerThreadAllreduce: Require the target attribute";
    ThreadAllreduceBuilder allreduce_builder(target.as<TargetNode>(),
                                             has_dynamic_shmem);
    func_node->body = allreduce_builder(func_node->body);
    return f;
  };
  return CreatePrimFuncPass(lower_func, 0, "tl.LowerThreadAllreduce", {});
}

// Expose the pass factory under "tl.transform.LowerThreadAllreduce" in the FFI
// global registry.
TVM_FFI_STATIC_INIT_BLOCK({
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tl.transform.LowerThreadAllreduce",
                        LowerThreadAllreduce);
});

} // namespace transform
} // namespace tl
} // namespace tvm