// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#pragma once

#include "common.h"

namespace tl {

// Thin wrappers around the PTX warp-collective matrix load/store instructions
// `ldmatrix` and `stmatrix` (m8n8 tiles of b16 elements, `.shared` space).
//
// Contract shared by every function below:
//  * `.sync.aligned`: all 32 lanes of a warp must execute the same call; do not
//    invoke these from warp-divergent control flow.
//  * `smem_ptr` must point into shared memory. It is converted by
//    `smem_ptr_to_uint` (declared in common.h; presumably a generic-to-.shared
//    address conversion) into the 32-bit address operand the instruction takes.
//    Each lane's address is the start of one 8-element (16-byte) matrix row:
//    lanes 0-7 address matrix 0, lanes 8-15 matrix 1, lanes 16-23 matrix 2,
//    lanes 24-31 matrix 3. Only the lanes covering the requested number of
//    matrices (x1: 0-7, x2: 0-15, x4: 0-31) provide used addresses.
//  * Every 32-bit register operand carries two packed b16 elements; the .xN
//    variant moves N registers per lane.
//  * `.trans` variants transpose each 8x8 matrix during the transfer.
//  * ldmatrix requires sm_75+ (PTX ISA 6.5); stmatrix requires sm_90+
//    (PTX ISA 7.8).
//  * NOTE(review): none of the asm statements declares a "memory" clobber, so
//    the compiler is not told these access shared memory. Ordering against
//    other shared-memory traffic must come from barriers (e.g. __syncthreads()),
//    which cross-lane data exchange needs anyway.

// Loads one 8x8 b16 matrix. Each lane writes 1 register to local_ptr[0].
// local_ptr must be 4-byte aligned with room for 1 x 32-bit value.
TL_DEVICE_NOINLINE void ptx_ldmatrix_x1(void const *const smem_ptr,
                                        void *const local_ptr) {
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
  int32_t *value = reinterpret_cast<int32_t *>(local_ptr);
  asm volatile("ldmatrix.sync.aligned.x1.m8n8.shared.b16 {%0}, [%1];\n"
               : "=r"(value[0])
               : "r"(smem_int_ptr));
}

// Loads two 8x8 b16 matrices. Each lane writes 2 registers to
// local_ptr[0..1] (4-byte aligned, room for 2 x 32-bit values).
TL_DEVICE_NOINLINE void ptx_ldmatrix_x2(void const *const smem_ptr,
                                        void *const local_ptr) {
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
  int32_t *value = reinterpret_cast<int32_t *>(local_ptr);
  asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n"
               : "=r"(value[0]), "=r"(value[1])
               : "r"(smem_int_ptr));
}

// Loads four 8x8 b16 matrices. Each lane writes 4 registers to
// local_ptr[0..3] (4-byte aligned, room for 4 x 32-bit values).
TL_DEVICE_NOINLINE void ptx_ldmatrix_x4(void const *const smem_ptr,
                                        void *const local_ptr) {
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
  int32_t *value = reinterpret_cast<int32_t *>(local_ptr);
  asm volatile(
      "ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"
      : "=r"(value[0]), "=r"(value[1]), "=r"(value[2]), "=r"(value[3])
      : "r"(smem_int_ptr));
}

// Transposing variant of ptx_ldmatrix_x1: 1 register per lane in local_ptr[0].
TL_DEVICE_NOINLINE void ptx_ldmatrix_x1_trans(void const *const smem_ptr,
                                              void *const local_ptr) {
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
  int32_t *value = reinterpret_cast<int32_t *>(local_ptr);
  asm volatile("ldmatrix.sync.aligned.x1.trans.m8n8.shared.b16 {%0}, [%1];\n"
               : "=r"(value[0])
               : "r"(smem_int_ptr));
}

// Transposing variant of ptx_ldmatrix_x2: 2 registers per lane in
// local_ptr[0..1].
TL_DEVICE_NOINLINE void ptx_ldmatrix_x2_trans(void const *const smem_ptr,
                                              void *const local_ptr) {
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
  int32_t *value = reinterpret_cast<int32_t *>(local_ptr);
  asm volatile(
      "ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16 {%0, %1}, [%2];\n"
      : "=r"(value[0]), "=r"(value[1])
      : "r"(smem_int_ptr));
}

// Transposing variant of ptx_ldmatrix_x4: 4 registers per lane in
// local_ptr[0..3].
TL_DEVICE_NOINLINE void ptx_ldmatrix_x4_trans(void const *const smem_ptr,
                                              void *const local_ptr) {
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
  int32_t *value = reinterpret_cast<int32_t *>(local_ptr);
  asm volatile(
      "ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"
      : "=r"(value[0]), "=r"(value[1]), "=r"(value[2]), "=r"(value[3])
      : "r"(smem_int_ptr));
}

// Stores one 8x8 b16 matrix to shared memory from value0 of every lane.
// NOTE(review): smem_ptr is the store destination but is declared
// pointer-to-const. It is kept that way for interface compatibility.
TL_DEVICE void ptx_stmatrix_x1(void const *const smem_ptr,
                               const int32_t &value0) {
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
  asm volatile("stmatrix.sync.aligned.x1.m8n8.shared.b16 [%0], {%1};\n"
               :
               : "r"(smem_int_ptr), "r"(value0));
}

// Stores two 8x8 b16 matrices from value0/value1 of every lane.
TL_DEVICE void ptx_stmatrix_x2(void const *const smem_ptr,
                               const int32_t &value0, const int32_t &value1) {
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
  asm volatile("stmatrix.sync.aligned.x2.m8n8.shared.b16 [%0], {%1, %2};\n"
               :
               : "r"(smem_int_ptr), "r"(value0), "r"(value1));
}

// Stores four 8x8 b16 matrices from value0..value3 of every lane.
TL_DEVICE void ptx_stmatrix_x4(void const *const smem_ptr,
                               const int32_t &value0, const int32_t &value1,
                               const int32_t &value2, const int32_t &value3) {
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
  asm volatile(
      "stmatrix.sync.aligned.x4.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n"
      :
      : "r"(smem_int_ptr), "r"(value0), "r"(value1), "r"(value2),
        "r"(value3));
}

// Transposing variant of ptx_stmatrix_x1 (value0 of every lane).
TL_DEVICE void ptx_stmatrix_x1_trans(void const *const smem_ptr,
                                     const int32_t &value0) {
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
  asm volatile("stmatrix.sync.aligned.x1.trans.m8n8.shared.b16 [%0], {%1};\n"
               :
               : "r"(smem_int_ptr), "r"(value0));
}

// Transposing variant of ptx_stmatrix_x2 (value0/value1 of every lane).
TL_DEVICE void ptx_stmatrix_x2_trans(void const *const smem_ptr,
                                     const int32_t &value0,
                                     const int32_t &value1) {
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
  asm volatile(
      "stmatrix.sync.aligned.x2.trans.m8n8.shared.b16 [%0], {%1, %2};\n"
      :
      : "r"(smem_int_ptr), "r"(value0), "r"(value1));
}

// Transposing variant of ptx_stmatrix_x4 (value0..value3 of every lane).
TL_DEVICE void ptx_stmatrix_x4_trans(void const *const smem_ptr,
                                     const int32_t &value0,
                                     const int32_t &value1,
                                     const int32_t &value2,
                                     const int32_t &value3) {
  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
  asm volatile(
      "stmatrix.sync.aligned.x4.trans.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n"
      :
      : "r"(smem_int_ptr), "r"(value0), "r"(value1), "r"(value2),
        "r"(value3));
}

} // namespace tl