// machete_pytorch.cu
#include "machete_mm_launcher.cuh"
#include "machete_prepack_launcher.cuh"
#include "core/scalar_type.hpp"

5
6
#include "core/registration.h"

7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
namespace machete {

using namespace vllm;

//
//  Utils (type dispatching)
//

// Maps a runtime vllm::ScalarType onto a compile-time cutlass element type.
//
// `fn` is a generic callable that is invoked with a value-initialized "tag"
// object of the cutlass type matching `type`; callers recover the type with
// decltype on that argument. The result of `fn` is returned unchanged. Because
// the return type is deduced, `fn` must yield the same type for every
// supported element type.
//
// Fails through TORCH_CHECK with "Unsupported type <name>" when `type` is none
// of kU4, kU8, kU4B8 or kU8B128.
template <typename Fn>
static auto scalar_type_dispatch(ScalarType const& type, Fn fn) {
  // Guard-style dispatch: the first matching type wins and returns.
  if (type == vllm::kU4) return fn(cutlass::uint4b_t{});
  if (type == vllm::kU8) return fn(cutlass::uint8_t{});
  if (type == vllm::kU4B8) return fn(cutlass::vllm_uint4b8_t{});
  if (type == vllm::kU8B128) return fn(cutlass::vllm_uint8b128_t{});

  // No supported type matched.
  TORCH_CHECK(false, "Unsupported type ", type.str());
}

// Case list for the compute (activation) dtypes this file dispatches over.
// It forwards to ATen's AT_DISPATCH_CASE_REDUCED_FLOATING_TYPES, so the set of
// accepted dtypes is whatever that ATen macro enumerates.
#define AT_DISPATCH_CASE_SUPPORTED_COMPUTE_TYPES(...) \
  AT_DISPATCH_CASE_REDUCED_FLOATING_TYPES(__VA_ARGS__)

// Full dispatch switch over the supported compute dtypes.
//   TYPE - runtime scalar type to switch on (e.g. a tensor's scalar_type()).
//   NAME - string passed through to AT_DISPATCH_SWITCH.
//   ...  - the callable to run for the matched dtype.
#define AT_DISPATCH_SUPPORTED_COMPUTE_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME,                             \
                     AT_DISPATCH_CASE_SUPPORTED_COMPUTE_TYPES(__VA_ARGS__))

//
//  Interface
//

41
// Returns the schedule names available for the B-operand type identified by
// `btype_id`.
//
// `btype_id` is converted to a ScalarType and dispatched to the matching
// cutlass element type; the list itself is produced by
// GemmDispatcher<half_t, BType>::supported_schedules(). The compute type is
// fixed to half_t for this query.
//
// Fails through TORCH_CHECK when the B type is not handled by
// scalar_type_dispatch, and unconditionally when the file was built without
// __CUDACC_VER_MAJOR__ defined or with a value below 12.
std::vector<std::string> supported_schedules(ScalarTypeId const btype_id) {
#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 12
  ScalarType const b_type = ScalarType::from_id(btype_id);
  return scalar_type_dispatch(b_type, [&](auto BType) {
    // BType is only a tag value; its type selects the dispatcher.
    return GemmDispatcher<half_t, decltype(BType)>::supported_schedules();
  });
#else
  TORCH_CHECK(false, "Machete requires CUDA 12.0 or later");
#endif
}

// PyTorch entry point for the Machete GEMM.
//
// Parameters:
//   A          - activation tensor; its scalar_type() selects the compute type
//                (one of the dtypes accepted by
//                AT_DISPATCH_SUPPORTED_COMPUTE_TYPES).
//   B          - weight tensor whose element type is described by `btype_id`.
//                NOTE(review): presumably already processed by prepack_B;
//                confirm against the callers.
//   btype_id   - id of the vllm::ScalarType of B; must map to one of the types
//                handled by scalar_type_dispatch.
//   scales, zeros, group_size, C, alpha, beta, schedule
//              - optional arguments forwarded untouched to the dispatcher via
//                PyTorchArguments; their interpretation lives in
//                GemmDispatcher.
//
// Returns the tensor produced by
// GemmDispatcher<ComputeType, BType>::dispatch(args).
//
// Fails through TORCH_CHECK for an unsupported B type, and unconditionally
// when the file was built without __CUDACC_VER_MAJOR__ defined or with a value
// below 12. An unsupported dtype of A is rejected by the ATen dispatch macro.
torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B,
                   ScalarTypeId const btype_id,
                   c10::optional<torch::Tensor> const& scales,
                   c10::optional<torch::Tensor> const& zeros,
                   c10::optional<int64_t> group_size,
                   c10::optional<torch::Tensor> const& C,
                   c10::optional<double> alpha, c10::optional<double> beta,
                   c10::optional<std::string> schedule) {
#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 12
  ScalarType const btype = ScalarType::from_id(btype_id);

  // Bundle every argument once so the dispatcher sees a single struct.
  auto args = PyTorchArguments{.A = A,
                               .B = B,
                               .scales = scales,
                               .zeros = zeros,
                               .group_size = group_size,
                               .C = C,
                               .alpha = alpha,
                               .beta = beta,
                               .schedule = schedule};

  // Two-level dispatch: the outer level picks the B element type, the inner
  // level picks the compute type from A's dtype.
  return scalar_type_dispatch(btype, [&](auto BType) {
    return AT_DISPATCH_SUPPORTED_COMPUTE_TYPES(
        A.scalar_type(), "machete_gemm", [&] {
          // scalar_t is introduced by the ATen dispatch macro.
          using ComputeType = equivalent_cutlass_type_t<scalar_t>;
          return GemmDispatcher<ComputeType, decltype(BType)>::dispatch(args);
        });
  });
#else
  TORCH_CHECK(false, "Machete requires CUDA 12.0 or later");
#endif
}

84
85
86
// PyTorch entry point that prepacks the weight tensor B.
//
// `btype_id` is converted to a ScalarType and dispatched to the matching
// cutlass element type; the work is done by
// PrepackBDispatcher<half_t, BType, half_t>::dispatch(B), whose result is
// returned. Both non-B template arguments are fixed to half_t here.
//
// Fails through TORCH_CHECK when the B type is not handled by
// scalar_type_dispatch.
//
// NOTE(review): unlike gemm() and supported_schedules(), this function has no
// `__CUDACC_VER_MAJOR__ >= 12` guard; confirm that this is intentional.
torch::Tensor prepack_B(torch::Tensor const& B, ScalarTypeId const btype_id) {
  ScalarType const btype = ScalarType::from_id(btype_id);
  return scalar_type_dispatch(btype, [&](auto BType) {
    // BType is only a tag value; its type selects the dispatcher.
    return PrepackBDispatcher<half_t, decltype(BType), half_t>::dispatch(B);
  });
}

// Registers the implementations of the tensor-taking ops under the CUDA
// dispatch key of the TORCH_EXTENSION_NAME library.
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
  m.impl("machete_prepack_B", &prepack_B);
  m.impl("machete_gemm", &gemm);
}

// Use CatchAll since supported_schedules has no tensor arguments, so there is
// no tensor from which a backend dispatch key could be derived.
TORCH_LIBRARY_IMPL(TORCH_EXTENSION_NAME, CatchAll, m) {
  m.impl("machete_supported_schedules", &supported_schedules);
}

};  // namespace machete