// ldsm.h: inline-PTX wrappers for ldmatrix / stmatrix (8x8 b16 matrix
// fragments moved between shared memory and per-thread registers).
#pragma once

#include "common.h"

namespace tl {

// Wrappers around the PTX `ldmatrix` / `stmatrix` instructions. Each one moves
// 1, 2 or 4 (`.x1`/`.x2`/`.x4`) 8x8 matrices of 16-bit elements
// (`.m8n8 ... .b16`) between shared memory (`.shared`) and per-thread 32-bit
// registers. `.trans` variants transpose each 8x8 matrix in flight.
//
// Common contract (per the PTX ISA ldmatrix / stmatrix sections):
//   * `.sync.aligned`: these are warp-collective. Every thread of the warp
//     must execute the same call convergently; calling from divergent code
//     is undefined.
//   * `smem_ptr` is this lane's row address. For .x1 lanes 0-7, for .x2 lanes
//     0-15 and for .x4 lanes 0-31 supply the addresses of the 8-row groups.
//     Each row address must be 16-byte aligned. On sm_75 every lane must hold
//     a valid address, even for .x1/.x2.
//   * Per matrix, each thread holds one 32-bit register containing two packed
//     b16 elements.
//   * ldmatrix requires sm_75+; stmatrix requires sm_90+.
//   * smem_ptr_to_uint (from common.h, not visible here) is assumed to convert
//     a generic pointer into the 32-bit shared-state-space address expected by
//     the `.shared` qualifier. NOTE(review): confirm against common.h.
//   * The asm statements list no "memory" clobber, so the compiler is not told
//     they read/write memory. Ordering against surrounding ordinary
//     shared-memory accesses must come from explicit synchronization at the
//     call site (e.g. __syncthreads()).

// Loads one 8x8 b16 matrix from shared memory.
// local_ptr: must point to >= 1 writable, 4-byte-aligned 32-bit slot;
//            receives value[0].
TL_DEVICE_NOINLINE void ptx_ldmatrix_x1(void const *const smem_ptr,
                                        void *const local_ptr) {
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
  int32_t *value = reinterpret_cast<int32_t *>(local_ptr);
  asm volatile("ldmatrix.sync.aligned.x1.m8n8.shared.b16 {%0}, [%1];\n"
               : "=r"(value[0])
               : "r"(smem_int_ptr));
}

// Loads two 8x8 b16 matrices from shared memory.
// local_ptr: must point to >= 2 writable, 4-byte-aligned 32-bit slots;
//            receives value[0..1] (one register per matrix).
TL_DEVICE_NOINLINE void ptx_ldmatrix_x2(void const *const smem_ptr,
                                        void *const local_ptr) {
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
  int32_t *value = reinterpret_cast<int32_t *>(local_ptr);
  asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n"
               : "=r"(value[0]), "=r"(value[1])
               : "r"(smem_int_ptr));
}

// Loads four 8x8 b16 matrices from shared memory.
// local_ptr: must point to >= 4 writable, 4-byte-aligned 32-bit slots;
//            receives value[0..3] (one register per matrix).
TL_DEVICE_NOINLINE void ptx_ldmatrix_x4(void const *const smem_ptr,
                                        void *const local_ptr) {
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
  int32_t *value = reinterpret_cast<int32_t *>(local_ptr);
  asm volatile(
      "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"
      : "=r"(value[0]), "=r"(value[1]), "=r"(value[2]), "=r"(value[3])
      : "r"(smem_int_ptr));
}

// Transposing variant of ptx_ldmatrix_x1: loads one 8x8 b16 matrix,
// transposed, into value[0].
// local_ptr: must point to >= 1 writable, 4-byte-aligned 32-bit slot.
TL_DEVICE_NOINLINE void ptx_ldmatrix_x1_trans(void const *const smem_ptr,
                                              void *const local_ptr) {
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
  int32_t *value = reinterpret_cast<int32_t *>(local_ptr);
  asm volatile("ldmatrix.sync.aligned.x1.trans.m8n8.shared.b16 {%0}, [%1];\n"
               : "=r"(value[0])
               : "r"(smem_int_ptr));
}

// Transposing variant of ptx_ldmatrix_x2: loads two 8x8 b16 matrices,
// each transposed, into value[0..1].
// local_ptr: must point to >= 2 writable, 4-byte-aligned 32-bit slots.
TL_DEVICE_NOINLINE void ptx_ldmatrix_x2_trans(void const *const smem_ptr,
                                              void *const local_ptr) {
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
  int32_t *value = reinterpret_cast<int32_t *>(local_ptr);
  asm volatile(
      "ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16 {%0, %1}, [%2];\n"
      : "=r"(value[0]), "=r"(value[1])
      : "r"(smem_int_ptr));
}

// Transposing variant of ptx_ldmatrix_x4: loads four 8x8 b16 matrices,
// each transposed, into value[0..3].
// local_ptr: must point to >= 4 writable, 4-byte-aligned 32-bit slots.
TL_DEVICE_NOINLINE void ptx_ldmatrix_x4_trans(void const *const smem_ptr,
                                              void *const local_ptr) {
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
  int32_t *value = reinterpret_cast<int32_t *>(local_ptr);
  asm volatile(
      "ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"
      : "=r"(value[0]), "=r"(value[1]), "=r"(value[2]), "=r"(value[3])
      : "r"(smem_int_ptr));
}

// Stores one 8x8 b16 matrix to shared memory from this thread's 32-bit
// fragment value0 (requires sm_90+).
TL_DEVICE void ptx_stmatrix_x1(void const *const smem_ptr,
                               const int32_t &value0) {
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
  asm volatile("stmatrix.sync.aligned.x1.m8n8.shared.b16 [%0], {%1};\n" ::"r"(
                   smem_int_ptr),
               "r"(value0));
}

// Stores two 8x8 b16 matrices to shared memory; value0/value1 are this
// thread's fragments of matrix 0/1 (requires sm_90+).
TL_DEVICE void ptx_stmatrix_x2(void const *const smem_ptr,
                               const int32_t &value0, const int32_t &value1) {
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
  asm volatile(
      "stmatrix.sync.aligned.x2.m8n8.shared.b16 [%0], {%1, %2};\n" ::"r"(
          smem_int_ptr),
      "r"(value0), "r"(value1));
}

// Stores four 8x8 b16 matrices to shared memory; value0..value3 are this
// thread's fragments of matrices 0..3 (requires sm_90+).
TL_DEVICE void ptx_stmatrix_x4(void const *const smem_ptr,
                               const int32_t &value0, const int32_t &value1,
                               const int32_t &value2, const int32_t &value3) {
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
  asm volatile(
      "stmatrix.sync.aligned.x4.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n" ::
          "r"(smem_int_ptr),
      "r"(value0), "r"(value1), "r"(value2), "r"(value3));
}

// Transposing variant of ptx_stmatrix_x1: stores one 8x8 b16 matrix,
// transposed (requires sm_90+).
TL_DEVICE void ptx_stmatrix_x1_trans(void const *const smem_ptr,
                                     const int32_t &value0) {
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
  asm volatile(
      "stmatrix.sync.aligned.x1.trans.m8n8.shared.b16 [%0], {%1};\n" ::"r"(
          smem_int_ptr),
      "r"(value0));
}

// Transposing variant of ptx_stmatrix_x2: stores two 8x8 b16 matrices,
// each transposed (requires sm_90+).
TL_DEVICE void ptx_stmatrix_x2_trans(void const *const smem_ptr,
                                     const int32_t &value0,
                                     const int32_t &value1) {
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
  asm volatile(
      "stmatrix.sync.aligned.x2.trans.m8n8.shared.b16 [%0], {%1, %2};\n" ::"r"(
          smem_int_ptr),
      "r"(value0), "r"(value1));
}

// Transposing variant of ptx_stmatrix_x4: stores four 8x8 b16 matrices,
// each transposed (requires sm_90+).
TL_DEVICE void ptx_stmatrix_x4_trans(void const *const smem_ptr,
                                     const int32_t &value0,
                                     const int32_t &value1,
                                     const int32_t &value2,
                                     const int32_t &value3) {
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
  asm volatile("stmatrix.sync.aligned.x4.trans.m8n8.shared.b16 [%0], {%1, %2, "
               "%3, %4};\n" ::"r"(smem_int_ptr),
               "r"(value0), "r"(value1), "r"(value2), "r"(value3));
}

} // namespace tl