machete_pytorch.cu 3.04 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
#include "machete_mm_launcher.cuh"
#include "machete_prepack_launcher.cuh"
#include "core/scalar_type.hpp"

namespace machete {

using namespace vllm;

//
//  Utils (type dispatching)
//

// Translates a runtime `ScalarType` into a compile-time CUTLASS element type.
//
// `fn` is invoked with a default-constructed value of the matching CUTLASS
// type; callers recover the type with `decltype` on the argument. Because the
// return type is deduced (`auto`), every instantiation of `fn` must return the
// same type.
//
// Fails through TORCH_CHECK with "Unsupported type <name>" when `type` is not
// one of kU4, kU8, kU4B8 or kU8B128.
template <typename Fn>
static auto scalar_type_dispatch(ScalarType const& type, Fn fn) {
  // Unsigned 4-bit weights.
  if (type == vllm::kU4) return fn(cutlass::uint4b_t{});

  // Unsigned 8-bit weights.
  if (type == vllm::kU8) return fn(cutlass::uint8_t{});

  // kU4B8 maps onto the vLLM-specific CUTLASS type vllm_uint4b8_t.
  if (type == vllm::kU4B8) return fn(cutlass::vllm_uint4b8_t{});

  // kU8B128 maps onto the vLLM-specific CUTLASS type vllm_uint8b128_t.
  if (type == vllm::kU8B128) return fn(cutlass::vllm_uint8b128_t{});

  // No supported type matched: report the offending type by name.
  TORCH_CHECK(false, "Unsupported type ", type.str());
}

// Case list for the compute (activation) dtypes this file dispatches over.
// It forwards unchanged to ATen's AT_DISPATCH_CASE_REDUCED_FLOATING_TYPES, so
// the supported set is whatever that ATen macro enumerates.
// NOTE(review): in ATen this is expected to be Half and BFloat16 -- confirm
// against the ATen version in use.
#define AT_DISPATCH_CASE_SUPPORTED_COMPUTE_TYPES(...) \
  AT_DISPATCH_CASE_REDUCED_FLOATING_TYPES(__VA_ARGS__)

// Dispatch entry point: wraps AT_DISPATCH_SWITCH around the case list above.
//   TYPE - runtime scalar type to switch on
//   NAME - name string passed through to AT_DISPATCH_SWITCH
//   ...  - callable body forwarded to each generated case
#define AT_DISPATCH_SUPPORTED_COMPUTE_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME,                             \
                     AT_DISPATCH_CASE_SUPPORTED_COMPUTE_TYPES(__VA_ARGS__))

//
//  Interface
//

// Returns the schedule names reported by the Machete GEMM dispatcher for the
// weight (B) type described by `btype`.
//
// `btype` is dereferenced and resolved to a CUTLASS element type through
// scalar_type_dispatch; the list itself comes from
// GemmDispatcher<half_t, BType>::supported_schedules().
//
// Fails through TORCH_CHECK when the B type is not handled by
// scalar_type_dispatch, or when this translation unit is built with
// __CUDACC_VER_MAJOR__ undefined or lower than 12.
std::vector<std::string> supported_schedules(ScalarTypeTorchPtr const& btype) {
#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 12
  return scalar_type_dispatch(*btype, [&](auto BType) {
    // BType is only a type tag; its value is never used.
    return GemmDispatcher<half_t, decltype(BType)>::supported_schedules();
  });
#else
  TORCH_CHECK(false, "Machete requires CUDA 12.0 or later");
#endif
}

// Runs the Machete mixed-input GEMM and returns the resulting tensor.
//
// All tensor and scalar arguments are bundled unchanged into a
// PyTorchArguments struct and handed to GemmDispatcher::dispatch. The
// dispatcher is selected by two runtime types:
//   * the weight type in `btype`, resolved by scalar_type_dispatch, and
//   * the dtype of `A`, resolved by AT_DISPATCH_SUPPORTED_COMPUTE_TYPES and
//     converted to its CUTLASS equivalent via equivalent_cutlass_type_t.
//
// Parameters:
//   A, B        - input tensors (B is interpreted according to `btype`)
//   btype       - scalar type descriptor of B
//   scales      - optional tensor forwarded as `.scales`
//   zeros       - optional tensor forwarded as `.zeros`
//   group_size  - optional value forwarded as `.group_size`
//   C           - optional tensor forwarded as `.C`
//   alpha, beta - optional scalars forwarded as `.alpha` / `.beta`
//   schedule    - optional schedule name forwarded as `.schedule`
//
// Fails through TORCH_CHECK when the B type is not handled by
// scalar_type_dispatch, or when this translation unit is built with
// __CUDACC_VER_MAJOR__ undefined or lower than 12.
torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B,
                   ScalarTypeTorchPtr const& btype,
                   c10::optional<torch::Tensor> const& scales,
                   c10::optional<torch::Tensor> const& zeros,
                   c10::optional<int64_t> group_size,
                   c10::optional<torch::Tensor> const& C,
                   c10::optional<double> alpha, c10::optional<double> beta,
                   c10::optional<std::string> schedule) {
#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 12
  auto args = PyTorchArguments{.A = A,
                               .B = B,
                               .scales = scales,
                               .zeros = zeros,
                               .group_size = group_size,
                               .C = C,
                               .alpha = alpha,
                               .beta = beta,
                               .schedule = schedule};

  // Outer dispatch picks the weight element type, inner dispatch picks the
  // compute type from A's dtype (`scalar_t` is provided by the ATen macro).
  return scalar_type_dispatch(*btype, [&](auto BType) {
    return AT_DISPATCH_SUPPORTED_COMPUTE_TYPES(
        A.scalar_type(), "machete_gemm", [&] {
          using ComputeType = equivalent_cutlass_type_t<scalar_t>;
          return GemmDispatcher<ComputeType, decltype(BType)>::dispatch(args);
        });
  });
#else
  TORCH_CHECK(false, "Machete requires CUDA 12.0 or later");
#endif
}

// Prepacks the weight tensor `B` and returns the tensor produced by the
// prepack dispatcher.
//
// The weight type in `btype` is resolved to a CUTLASS element type through
// scalar_type_dispatch, and the work is delegated to
// PrepackBDispatcher<half_t, BType, half_t>::dispatch(B).
//
// Fails through TORCH_CHECK when the B type is not handled by
// scalar_type_dispatch, or when this translation unit is built with
// __CUDACC_VER_MAJOR__ undefined or lower than 12.
torch::Tensor prepack_B(torch::Tensor const& B,
                        ScalarTypeTorchPtr const& btype) {
#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 12
  return scalar_type_dispatch(*btype, [&](auto BType) {
    // BType is only a type tag; its value is never used.
    return PrepackBDispatcher<half_t, decltype(BType), half_t>::dispatch(B);
  });
#else
  TORCH_CHECK(false, "Machete requires CUDA 12.0 or later");
#endif
}

};  // namespace machete