sparse_compressor_c3x.cu 6.44 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
// clang-format will break include orders
// clang-format off
#include <cudaTypedefs.h>

#include "sparse_scaled_mm_c3x.cuh"

#include "cutlass/numeric_conversion.h"
#include "cutlass/transform/device/transform_universal_adapter.hpp"
#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"

#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
// clang-format on

using namespace cute;
using namespace vllm;

/// Make A structured sparse by replacing elements with 0 and compress it.
///
/// A sparse GEMM kernel type is assembled below purely to obtain the
/// `SparseConfig` / `ElementE` types the compressor is parameterized on; no
/// matmul is launched from this function.
///
/// @tparam ElementA_   CUTLASS element type used to reinterpret the storage of
///                     `a` and `a_nzs`.
/// @tparam ElementAcc_ Accumulator type used to configure the GEMM kernel type.
/// @param a_nzs  Caller-allocated output buffer handed to the compressor as the
///               destination for the compressed values. Its size is not
///               validated here.
/// @param a_meta Caller-allocated output buffer handed to the compressor as the
///               destination for the sparsity metadata. Its size is not
///               validated here.
/// @param a      2-D, contiguous, row-major CUDA tensor to compress.
/// @return Always `true` when the function returns normally. Precondition
///         violations are reported through TORCH_CHECK; CUTLASS/CUDA failures
///         go through CUTLASS_CHECK / CUDA_CHECK (defined outside this file).
template <typename ElementA_, typename ElementAcc_>
bool cutlass_sparse_compress(torch::Tensor& a_nzs, torch::Tensor& a_meta,
                             torch::Tensor const& a) {
  // Checks for conformality
  TORCH_CHECK(a.dtype() == torch::kInt8 || a.dtype() == torch::kFloat8_e4m3fn ||
                  a.dtype() == torch::kFloat16 || a.dtype() == torch::kBFloat16,
              "cutlass_sparse_compress: a must be int8, float8_e4m3fn, "
              "float16 or bfloat16");
  TORCH_CHECK(a.dim() == 2, "cutlass_sparse_compress: a must be 2-D");
  // Check for strides and alignment
  TORCH_CHECK(a.stride(0) % 4 == 0,
              "cutlass_sparse_compress: a.stride(0) must be a multiple of 4");
  // ^ Required for semi-structured sparsity
  TORCH_CHECK(a.stride(1) == 1,
              "cutlass_sparse_compress: a must have unit stride in dim 1");
  // The stride passed to the compressor below is a *packed* stride derived
  // from the shape alone, so a row-strided view would be read incorrectly.
  TORCH_CHECK(a.is_contiguous(),
              "cutlass_sparse_compress: a must be contiguous");
  // Raw data pointers are handed to a device kernel, so every tensor has to
  // live on the same CUDA device.
  TORCH_CHECK(a.is_cuda() && a_nzs.is_cuda() && a_meta.is_cuda(),
              "cutlass_sparse_compress: a, a_nzs and a_meta must be CUDA "
              "tensors");
  TORCH_CHECK(a_nzs.device() == a.device() && a_meta.device() == a.device(),
              "cutlass_sparse_compress: a, a_nzs and a_meta must be on the "
              "same device");

  int const m = static_cast<int>(a.size(0));
  int const k = static_cast<int>(a.size(1));

  // Sparse kernel setup; this kernel is not used for matmul,
  // but just for setting up the compressor utility
  // A matrix configuration
  using ElementA = ElementA_;
  using LayoutTagA = cutlass::layout::RowMajor;
  constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
  // B matrix configuration
  using ElementB = ElementA;
  using LayoutTagB = cutlass::layout::ColumnMajor;
  constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
  // C/D matrix configuration
  using ElementC = float;
  using LayoutTagC = cutlass::layout::ColumnMajor;
  constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
  // Core kernel configurations
  using ElementAccumulator = ElementAcc_;
  using TileShape = Shape<_128, _128, _128>;
  using ClusterShape = Shape<_1, _2, _1>;
  // FP8 inputs select the fast-accumulation schedule; everything else uses the
  // plain TMA warp-specialized schedule.
  using KernelSchedule =
      std::conditional_t<std::is_same_v<ElementA, cutlass::float_e4m3_t>,
                         cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum,
                         cutlass::gemm::KernelTmaWarpSpecialized>;

  using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;
  using ProblemShape = Shape<int, int, int, int>;

  using CollectiveEpilogue =
      typename cutlass::epilogue::collective::CollectiveBuilder<
          cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape,
          ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
          ElementAccumulator, ElementAccumulator, ElementC, LayoutTagC,
          AlignmentC, ElementC, LayoutTagC, AlignmentC,
          EpilogueSchedule>::CollectiveOp;

  using CollectiveMainloop =
      typename cutlass::gemm::collective::CollectiveBuilder<
          cutlass::arch::Sm90, cutlass::arch::OpClassSparseTensorOp, ElementA,
          LayoutTagA, AlignmentA, ElementB, LayoutTagB, AlignmentB,
          ElementAccumulator, TileShape, ClusterShape,
          cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
              sizeof(typename CollectiveEpilogue::SharedStorage))>,
          KernelSchedule>::CollectiveOp;

  using GemmKernel =
      cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop,
                                           CollectiveEpilogue>;

  using StrideA = cutlass::gemm::TagToStrideA_t<LayoutTagA>;

  // Types the compressor needs, taken from the sparse mainloop.
  using ElementE = typename GemmKernel::CollectiveMainloop::ElementE;
  using SparseConfig = typename GemmKernel::CollectiveMainloop::SparseConfig;

  // Offline compressor kernel
  using CompressorKernel =
      cutlass::transform::kernel::StructuredSparseCompressor<
          ProblemShape, ElementA, LayoutTagA, SparseConfig,
          cutlass::arch::Sm90>;

  using Compressor =
      cutlass::transform::device::TransformUniversalAdapter<CompressorKernel>;

  // The n (=1) dimension does not matter for the compressor
  ProblemShape prob_shape{m, 1, k, 1};

  // Packed (M, K, L = 1) stride for A.
  StrideA stride_A =
      cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));

  auto a_ptr = static_cast<ElementA*>(a.data_ptr());
  auto a_nzs_ptr = static_cast<ElementA*>(a_nzs.data_ptr());
  auto a_meta_ptr = static_cast<ElementE*>(a_meta.data_ptr());

  // Query the SM count of the device that actually holds `a`, not device 0.
  // NOTE(review): the launch itself targets the current CUDA device;
  // presumably the caller has a device guard for `a` in place - confirm.
  cutlass::KernelHardwareInfo hw_info;
  hw_info.device_id = static_cast<int>(a.get_device());
  hw_info.sm_count =
      cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
          hw_info.device_id);
  typename Compressor::Arguments arguments{
      prob_shape, {a_ptr, stride_A, a_nzs_ptr, a_meta_ptr}, {hw_info}};

  Compressor compressor_op;
  size_t workspace_size = Compressor::get_workspace_size(arguments);
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  CUTLASS_CHECK(compressor_op.can_implement(arguments));
  CUTLASS_CHECK(compressor_op.initialize(arguments, workspace.get()));
  CUTLASS_CHECK(compressor_op.run());
  // run() is invoked without a stream argument, so wait for the whole device:
  // the outputs must be complete on return and the local workspace must
  // outlive the kernel.
  CUDA_CHECK(cudaDeviceSynchronize());

  return true;
}

/// Runtime dtype dispatcher for the SM90 sparse compressor.
///
/// Selects the `cutlass_sparse_compress` instantiation that matches
/// `a.dtype()` and forwards the three tensors unchanged. bfloat16, float16
/// and float8_e4m3fn are paired with a `float` accumulator type; int8 is
/// paired with `int32_t`.
///
/// @param a_nzs  Output tensor forwarded to the compressor.
/// @param a_meta Output tensor forwarded to the compressor.
/// @param a      Input tensor whose dtype drives the dispatch.
/// @return The result of the selected instantiation, or `false` when
///         `a.dtype()` is none of the four handled types.
bool cutlass_sparse_compress_sm90(torch::Tensor& a_nzs, torch::Tensor& a_meta,
                                  torch::Tensor const& a) {
  auto const dtype = a.dtype();

  if (dtype == torch::kBFloat16) {
    using Element = cutlass::bfloat16_t;
    return cutlass_sparse_compress<Element, float>(a_nzs, a_meta, a);
  }

  if (dtype == torch::kFloat16) {
    using Element = cutlass::half_t;
    return cutlass_sparse_compress<Element, float>(a_nzs, a_meta, a);
  }

  if (dtype == torch::kFloat8_e4m3fn) {
    using Element = cutlass::float_e4m3_t;
    return cutlass_sparse_compress<Element, float>(a_nzs, a_meta, a);
  }

  if (dtype == torch::kInt8) {
    return cutlass_sparse_compress<int8_t, int32_t>(a_nzs, a_meta, a);
  }

  // Unsupported dtype: report failure to the caller instead of throwing.
  return false;
}