/***************************************************************************************************
 * Copyright (c) 2023 - 2025 Hygon Information Technology Co., Ltd. All rights reserved.
 * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
#pragma once

#include "hute/layout.hpp"
#include "hute/tensor.hpp"
#include "hute/util/print.hpp"

namespace example {

using namespace hute;

// Empty type used to disable gather/scatter for a GEMM argument
struct NoGather
{
  template<class... Ts>
  NoGather(Ts...) {}
};

/// Function object that applies an index to its argument
template <class Index>
struct IndexedGather
{
  HUTE_HOST_DEVICE constexpr
  IndexedGather(Index const *indices = {}): indices_(indices) {}

  template <typename I>
  HUTE_HOST_DEVICE constexpr
  Index
  operator()(I i) const { return indices_[i]; }

  HUTE_HOST_DEVICE friend
  void 
  print(IndexedGather const &s) {
    hute::print("Indexed");
  }

  Index const *indices_;
};
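
// Illustrative usage (a sketch; the pointer and values are assumed):
//
//   int const* row_indices = /* device array, e.g. {3, 0, 2} */;
//   IndexedGather<int> gather{row_indices};
//   auto r = gather(1);   // reads row_indices[1], i.e. 0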

/// Function object that applies a stride to its argument
/// Example: StridedGather<_2> gathers every other row/column
template <class Stride>
struct StridedGather
{
  HUTE_HOST_DEVICE constexpr
  StridedGather(Stride stride = {}): stride_(stride) {}

  template <class I>
  HUTE_HOST_DEVICE constexpr
  auto
  operator()(I i) const { return i * stride_; }

  HUTE_HOST_DEVICE friend
  void 
  print(StridedGather const &s) {
    hute::print("Strided{");
    print(s.stride_);
    hute::print("}");
  }

  Stride stride_;
};
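
// Illustrative usage (a sketch): StridedGather<_2> maps i -> i * _2{}, so it
// visits every other element when used as the gather function:
//
//   StridedGather<_2> gather{};
//   auto j = gather(3);   // 3 * _2{} == 6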

/// Custom stride object that applies a function followed by a stride
template <class Func, class Stride>
struct CustomStride
{
  HUTE_HOST_DEVICE constexpr
  CustomStride(Func const &func, Stride const &stride): func_(func), stride_(stride) {}

  template <class I>
  HUTE_HOST_DEVICE constexpr friend
  auto
  operator*(I i, CustomStride const &s) { return s.func_(i) * s.stride_; }

  template <class I>
  HUTE_HOST_DEVICE constexpr friend
  auto
  operator*(CustomStride const &s, I i) { return s.func_(i) * s.stride_; }

  HUTE_HOST_DEVICE friend
  void
  print(CustomStride const & s) {
    hute::print("Custom{");
    print(s.func_);
    hute::print(",");
    print(s.stride_);
    hute::print("}");
  }

  template<class Div>
  HUTE_HOST_DEVICE constexpr friend
  auto
  safe_div(CustomStride const &s, Div const &div)
  {
    return CustomStride<Func, decltype(safe_div(s.stride_, div))>(s.func_, safe_div(s.stride_, div));
  }

  // Circumvent the requirement on make_layout that shape and stride are integral
  template <class Shape>
  HUTE_HOST_DEVICE constexpr friend
  auto
  make_layout(Shape const &shape, CustomStride const &stride)
  {
    return Layout<Shape, CustomStride>(shape, stride);
  }

  Func func_;
  Stride stride_;
};
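
// Illustrative usage (a sketch; values assumed): the functor is applied first,
// then the stride, so a logical index i lands at func(i) * stride elements:
//
//   CustomStride s{IndexedGather<int>{row_indices}, 128};
//   auto offset = 1 * s;  // row_indices[1] * 128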

template<class Stride, class Func>
HYTLASS_HOST_DEVICE
auto
make_custom_stride_layout(Stride const &stride, Func&& func)
{
  // Use a dummy shape and replace the first non-unit stride with a custom gather stride
  auto idx = find_if(stride, [](auto x){ return not is_constant<1, decltype(x)>{}; });
  constexpr int I = decltype(idx)::value;
  return make_layout(repeat_like(stride, _1{}),
                     replace<I>(stride, CustomStride{static_cast<Func&&>(func), get<I>(stride)}));
}
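
// For example (a sketch, assuming a row-major stride (lda, _1{})): mode 0 is
// the first non-unit stride, so it is replaced by CustomStride{func, lda} and
// the resulting layout has the dummy shape (_1, _1):
//
//   auto gl = make_custom_stride_layout(make_stride(lda, _1{}),
//                                       IndexedGather<int>{row_indices});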

/// Helper function to optionally create a gather tensor
template<class Iterator, class Shape, class Stride, class Func>
HYTLASS_HOST_DEVICE
auto 
make_gather_tensor(Iterator iter, Shape const &shape, Stride const &stride, Func &&func)
{
  if constexpr (not hytlass::platform::is_same<remove_cvref_t<Func>, NoGather>::value) {
    Layout matrix_layout = make_identity_layout(shape);
    auto offset = as_arithmetic_tuple(repeat_like(shape, _0{}));
    Layout gather_layout = make_custom_stride_layout(stride, static_cast<Func&&>(func));
    return make_tensor(iter, ComposedLayout{gather_layout, offset, matrix_layout});
  } else {
    return make_tensor(iter, shape, stride);
  }
}
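
// Usage sketch (pointer, extents, and leading dimension are assumed): gather
// rows of an M x K row-major matrix through `row_indices`; passing NoGather{}
// instead falls back to a plain tensor with the given shape and stride:
//
//   auto tA = make_gather_tensor(ptr_A, make_shape(M, K),
//                                make_stride(lda, _1{}),
//                                IndexedGather<int>{row_indices});
//   auto tB = make_gather_tensor(ptr_A, make_shape(M, K),
//                                make_stride(lda, _1{}), NoGather{});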

} // namespace example

namespace hute
{

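/// Mode-targeted variant of upcast: divides shape and stride by N only along
/// the scaled-basis mode I, forwarding every other mode unchanged. It is used
/// by the ComposedLayout overload below to upcast the inner layout along the
/// stride-1 mode found in the outer layout.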
template<int N, int I, class Shape, class Stride>
HUTE_HOST_DEVICE constexpr
auto
upcast(Shape const& shape, Stride const& stride)
{
  if constexpr (is_tuple<Shape>::value) {
    return transform_layout(shape, stride, [](auto const& s, auto const& d) { return upcast<N,I>(s,d); });
  } else if constexpr (is_scaled_basis<Stride>::value) {
    if constexpr (Stride::mode() == I) {
      return make_layout(shape_div(shape, Int<N>{}), shape_div(stride, Int<N>{}));
    } else {
      return make_layout(shape, stride);
    }
  } else {
    return upcast<N>(shape, stride);
  }

  HUTE_GCC_UNREACHABLE;
}

template <int N, class OuterShape, class OuterStride, class Offset, class Shape, class Stride>
HUTE_HOST_DEVICE constexpr
auto
upcast(ComposedLayout<Layout<OuterShape,OuterStride>,Offset,Layout<Shape,Stride>> const& layout)
{
  // Find index of the stride-1 mode - that is the only one that requires updating inner shape and offset
  auto idx = find_if(layout.layout_a().stride(), [](auto x){ return is_constant<1, decltype(x)>{}; });
  constexpr int I = decltype(idx)::value;

  // Upcast the outer layout (works as expected)
  auto outer = upcast<N>(layout.layout_a());

  // Upcast the accumulated offset along stride-1 mode
  auto offset = as_arithmetic_tuple(replace<I>(layout.offset(), upcast<N>(get<I>(layout.offset()))));

  // Upcast the inner layout's shape along stride-1 mode
  auto inner = upcast<N,I>(layout.layout_b().shape(), layout.layout_b().stride());

  return composition(outer, offset, inner);
}
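
// Note (an assumption about intent): recasting a gather tensor to a wider
// value type routes through this overload; the outer layout's CustomStride is
// rescaled via the safe_div friend defined above, while the offset and inner
// shape are divided only along the stride-1 mode.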

} // namespace hute