torch_utils.hpp 5.38 KB
Newer Older
1
2
#pragma once

3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
// This header is shared between _C (unstable ABI, used by machete) and
// _C_stable_libtorch (stable ABI, used by W4A8/sparse). TORCH_TARGET_VERSION
// is defined only for the stable target, so we switch includes and types
// accordingly. TorchTensor (not Tensor) avoids ambiguity with cute::Tensor.
// Both branches define the same two names used by the rest of this header:
//   TorchTensor       - the tensor type of the selected ABI
//   TORCH_UTILS_CHECK - the condition-check macro of the selected ABI
#ifdef TORCH_TARGET_VERSION
  // Stable-ABI (_C_stable_libtorch) build: header-only torch pieces only.
  #include <torch/csrc/stable/tensor.h>
  #include <torch/headeronly/util/BFloat16.h>
  #include <torch/headeronly/util/Half.h>
  #include <torch/headeronly/util/shim_utils.h>  // for STD_TORCH_CHECK
using TorchTensor = torch::stable::Tensor;
  #define TORCH_UTILS_CHECK STD_TORCH_CHECK
#else
  // Unstable-ABI (_C) build: full libtorch API.
  #include <torch/all.h>
using TorchTensor = torch::Tensor;
  #define TORCH_UTILS_CHECK TORCH_CHECK
#endif  // TORCH_TARGET_VERSION
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72

#include <optional>
#include <string_view>
#include <type_traits>

#include "cute/layout.hpp"
#include "cutlass/layout/matrix.h"
#include "cutlass/bfloat16.h"
#include "cutlass/half.h"

using ColumnMajor = typename cutlass::layout::ColumnMajor;
using RowMajor = typename cutlass::layout::RowMajor;

namespace cute {

namespace detail {

// Calls `f(cute::get<I>(t), I)` for every index I in `seq<I...>` and passes
// all of the results, in index order, to `g`.
template <class T, class F, class G, int... I>
CUTE_HOST_DEVICE constexpr auto tapply_with_idx(T&& t, F&& f, G&& g,
                                                seq<I...>) {
  return g(f(cute::get<I>(static_cast<T&&>(t)), I)...);
}

// Returns `make_shape(f(I)...)` for the indices in `seq<I...>`.
template <class F, int... I>
CUTE_HOST_DEVICE constexpr auto make_shape_from_idx(F&& f, seq<I...>) {
  return make_shape(f(I)...);
}

}  // namespace detail

// Element-wise transform that also hands the element's index to `f`:
// for a tuple `t` returns make_tuple(f(get<0>(t), 0), f(get<1>(t), 1), ...).
// A non-tuple `t` is treated as a single element at index 0, i.e. returns
// `f(t, 0)`, so the same index-taking functor works for both cases. If `f`
// cannot take an index, `f(t)` is called instead (the original behavior for
// non-tuples, kept for backward compatibility).
template <class T, class F>
CUTE_HOST_DEVICE constexpr auto transform_with_idx(T const& t, F&& f) {
  if constexpr (cute::is_tuple<T>::value) {
    return detail::tapply_with_idx(
        t, f, [](auto const&... a) { return cute::make_tuple(a...); },
        tuple_seq<T>{});
  } else if constexpr (std::is_invocable_v<F&, T const&, int>) {
    return f(t, 0);
  } else {
    return f(t);
  }

  CUTE_GCC_UNREACHABLE;
}

// calls: make_shape(f(0), f(1), ..., f(N-1))
template <int N, class F>
CUTE_HOST_DEVICE constexpr auto make_shape_from_idx(F&& f) {
  return detail::make_shape_from_idx(f, make_seq<N>{});
}

}  // namespace cute

// Make a layout from a tensor with `rank(Stride{})`, where the shape is the
// shape of the passed in tensor and the strides are of type `Stride` and
// contain the strides of the passed in tensor, checking that any static strides
// in `Stride{}` match the strides of the passed in tensor.
// If `tensor.dim() < rank(Stride{})`, the shape is padded with 1s and the extra
// strides are set to be 0 or 1.
template <typename Stride>
73
static inline auto make_cute_layout(TorchTensor const& tensor,
74
                                    std::string_view name = "tensor") {
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
  TORCH_UTILS_CHECK(tensor.dim() <= rank(Stride{}));
  auto stride = cute::transform_with_idx(Stride{}, [&](auto const& stride_ele,
                                                       auto const& idx) {
    using StrideEle = std::decay_t<decltype(stride_ele)>;

    if (idx < tensor.dim()) {
      if constexpr (cute::is_static_v<StrideEle>) {
        TORCH_UTILS_CHECK(StrideEle::value == tensor.stride(idx), "Expected ",
                          name, ".stride(", idx, ") to be ", StrideEle::value);
        return StrideEle{};
      } else {
        if (tensor.size(idx) == 1) {
          // use 0 stride for dim with size 1, this is easier for
          // cute/cutlass to optimize (helps the TMA code flatten dims)
          return StrideEle{0};
90
        } else {
91
          return tensor.stride(idx);
92
        }
93
94
95
96
97
98
99
100
101
      }
    } else {
      // Extra strides are assumed to be 0 or 1
      if constexpr (cute::is_static_v<StrideEle>) {
        static_assert(StrideEle::value == 0 || StrideEle::value == 1);
      }
      return StrideEle{};
    }
  });
102
103
104
105
106
107
108
109
110
111
112
113
114

  auto shape = cute::make_shape_from_idx<rank(Stride{})>([&](auto const& idx) {
    if (idx < tensor.dim())
      return tensor.size(idx);
    else
      return int64_t(1);
  });

  return make_layout(shape, stride);
}

template <typename Stride>
static inline auto maybe_make_cute_layout(
115
    std::optional<TorchTensor> const& tensor,
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
    std::string_view name = "tensor") {
  using Layout = decltype(make_cute_layout<Stride>(*tensor));

  if (tensor) {
    return std::optional<Layout>{make_cute_layout<Stride>(*tensor, name)};
  } else {
    return std::optional<Layout>{};
  }
}

//
//  Torch Type to Cutlass Type (equivalent_cutlass_type)
//

// Maps a torch element type to the equivalent CUTLASS element type.
// Any type without a dedicated specialization maps to itself.
template <typename TorchT>
struct equivalent_cutlass_type {
  using type = TorchT;
};

// Shorthand for `typename equivalent_cutlass_type<TorchT>::type`.
template <typename TorchT>
using equivalent_cutlass_type_t =
    typename equivalent_cutlass_type<TorchT>::type;

template <>
139
struct equivalent_cutlass_type<torch::headeronly::Half> {
140
141
142
143
  using type = cutlass::half_t;
};

template <>
144
struct equivalent_cutlass_type<torch::headeronly::BFloat16> {
145
146
147
148
149
150
151
  using type = cutlass::bfloat16_t;
};

//
// equivalent_scalar_t (basically inverse of equivalent_cutlass_type)
//

// Maps an element type (possibly a CUTLASS type) to the C++ type that
// `torch::headeronly::CppTypeToScalarType` understands, e.g.
// `cutlass::half_t -> Half`. Any type without a dedicated specialization maps
// to itself.
template <typename ElemT>
struct equivalent_scalar_type {
  using type = ElemT;
};

// Shorthand for `typename equivalent_scalar_type<ElemT>::type`.
template <typename ElemT>
using equivalent_scalar_type_t =
    typename equivalent_scalar_type<ElemT>::type;

template <>
struct equivalent_scalar_type<cutlass::half_t> {
164
  using type = torch::headeronly::Half;
165
166
167
168
};

template <>
struct equivalent_scalar_type<cutlass::bfloat16_t> {
169
  using type = torch::headeronly::BFloat16;
170
171
};

172
// get equivalent torch::headeronly::ScalarType tag from compile time type
173
template <typename T>
174
175
static inline constexpr torch::headeronly::ScalarType equivalent_scalar_type_v =
    torch::headeronly::CppTypeToScalarType<equivalent_scalar_type_t<T>>::value;