// marlin.cuh: Marlin / repack launch constants and small device helpers
// (Vec, div_ceil, cp.async wrappers).
#pragma once

#include <torch/all.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <iostream>

// Name of the namespace that wraps this header's definitions. It can be
// overridden by defining MARLIN_NAMESPACE_NAME before this header is
// included (e.g. with -D on the compiler command line); otherwise it falls
// back to `marlin`.
#if !defined(MARLIN_NAMESPACE_NAME)
  #define MARLIN_NAMESPACE_NAME marlin
#endif

namespace MARLIN_NAMESPACE_NAME {

// ---- Marlin kernel parameters ---------------------------------------------

// Default block size: 256 threads, i.e. 8 warps of 32. Tuning rationale from
// the original authors: every SM has 4 warp schedulers, so more than one warp
// per scheduler buys some extra latency hiding, while keeping the warp count
// low leaves many registers per warp and keeps tiles small.
static constexpr int default_threads = 256;

// Number of pipeline stages; 4 stages fit into shared memory.
static constexpr int pipe_stages = 4;

// Bounds on the N / K "thread" tile extents (presumably the per-thread-block
// tile sizes along N and K -- confirm at the use sites).
static constexpr int min_thread_n = 64;
static constexpr int min_thread_k = 64;
static constexpr int max_thread_n = 256;

// Base tile edge length.
static constexpr int tile_size = 16;

// Upper limit for the `par` factor (its meaning is defined by the kernels
// that read it -- TODO confirm).
static constexpr int max_par = 16;

// ---- Repack parameters ----------------------------------------------------

static constexpr int repack_stages = 8;
static constexpr int repack_threads = 256;

// Repack tile extents: K uses the base tile size, N is four times that
// (16 and 64 with the values above).
static constexpr int tile_k_size = tile_size;
static constexpr int tile_n_size = 4 * tile_k_size;
// Helpers

// Fixed-size aggregate holding `n` elements of type `T` (no constructors, so
// it can be brace-initialized and copied by value).
template <typename T, int n>
struct Vec {
  T elems[n];

  // Unchecked mutable element access; `i` must lie in [0, n).
  __host__ __device__ T& operator[](int i) { return elems[i]; }

  // Unchecked read-only access, so a `const Vec` can be indexed as well.
  // Both accessors are usable from host and device code.
  __host__ __device__ const T& operator[](int i) const { return elems[i]; }
};

// Four ints bundled into one Vec.
using I4 = Vec<int, 4>;

// Ceiling of the exact quotient a / b (rounded toward +infinity); `b` must be
// non-zero. Computed as truncating quotient plus a correction instead of the
// classic `(a + b - 1) / b`. The classic form overflows `int` (undefined
// behaviour) when `a` is within `b - 1` of INT_MAX, and it rounds incorrectly
// for negative operands. For a >= 0 and b > 0 both forms give the same result.
constexpr int div_ceil(int a, int b) {
  // Round up only when there is a remainder and the exact quotient is
  // positive, i.e. a and b have the same sign.
  return a / b + ((a % b != 0) && ((a > 0) == (b > 0)) ? 1 : 0);
}

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
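// cp.async requires sm_80 or newer, so the cp.async helpers below are not
// defined when compiling device code for older architectures.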
#else

// Asynchronously copy 16 bytes from global memory (`glob_ptr`) into shared
// memory (`smem_ptr`) using PTX `cp.async.cg`. When `pred` is false the copy
// instruction is predicated off and nothing is transferred. Per the PTX ISA,
// both addresses must be aligned to the 16-byte copy size.
__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
                                      bool pred = true) {
  constexpr int kCopyBytes = 16;
  // cp.async expects a 32-bit shared-state-space address, not a generic one.
  const uint32_t smem_addr =
      static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  // The predicate is passed as an int and converted to a PTX .pred register.
  const int pred_flag = static_cast<int>(pred);
  asm volatile(
      "{\n"
      "   .reg .pred p;\n"
      "   setp.ne.b32 p, %0, 0;\n"
      "   @p cp.async.cg.shared.global [%1], [%2], %3;\n"
      "}\n"
      :
      : "r"(pred_flag), "r"(smem_addr), "l"(glob_ptr), "n"(kCopyBytes));
}
// Unconditional variant of cp_async4_pred: asynchronously copy 16 bytes from
// global memory (`glob_ptr`) into shared memory (`smem_ptr`) using PTX
// `cp.async.cg`. Per the PTX ISA, both addresses must be 16-byte aligned.
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
  constexpr int kCopyBytes = 16;
  // Convert the generic pointer to the 32-bit shared address cp.async needs.
  const uint32_t smem_addr =
      static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
      "{\n"
      "   cp.async.cg.shared.global [%0], [%1], %2;\n"
      "}\n"
      :
      : "r"(smem_addr), "l"(glob_ptr), "n"(kCopyBytes));
}
// Emit PTX `cp.async.commit_group`, which bundles this thread's
// not-yet-committed cp.async operations into one group that a later
// cp_async_wait<n>() can wait on.
__device__ inline void cp_async_fence() {
  asm volatile(
      "cp.async.commit_group;\n"
      ::);
}
// Emit PTX `cp.async.wait_group n`. Per the PTX ISA, this waits until at most
// `n` of this thread's most recently committed cp.async groups are still
// pending. `n` must be a compile-time constant because PTX encodes it as an
// immediate.
//
// The "memory" clobber tells the compiler that memory may have changed across
// this instruction. Without it, nothing in this statement stops ordinary
// (non-asm) loads of the copied shared-memory data from being scheduled
// before the wait completes.
template <int n>
__device__ inline void cp_async_wait() {
  asm volatile("cp.async.wait_group %0;\n" ::"n"(n) : "memory");
}

#endif

}  // namespace MARLIN_NAMESPACE_NAME