common.h 2.67 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#pragma once

#include <cuda_runtime.h>
#include <cutlass/fast_math.h>
#include <cutlass/numeric_types.h>
#include <math_constants.h>

using cutlass::bfloat16_t;
using cutlass::half_t;
using cutlass::tfloat32_t;

// Short names for the CUTLASS fast-math helpers (cutlass::fast_* from the
// <cutlass/fast_math.h> include above).
// NOTE(review): these are object-like macros, so every later token spelled
// `hexp`, `hlog`, `hsqrt` or `htanh` is rewritten -- including any intended
// use of identically named cuda_fp16 half intrinsics. Confirm this is wanted.
#define hexp cutlass::fast_exp
#define hlog cutlass::fast_log
#define hsqrt cutlass::fast_sqrt
#define htanh cutlass::fast_tanh
// `hpow` has no CUTLASS mapping here; it expands to single-precision powf.
#define hpow powf

// Short unsigned integer type names.
// NOTE(review): these are macros rather than typedefs, so they also rewrite
// any `uint`/`uchar`/`ushort` typedef in headers included after this one.
#define uint unsigned int
#define uchar unsigned char
#define ushort unsigned short

// Qualifier for small device-only helpers that should always be inlined.
#define TL_DEVICE __forceinline__ __device__

// Pack two half values into one 32-bit word.
// The raw 16-bit encoding of `x` goes in bits [15:0] and that of `y` in
// bits [31:16]; the bits are copied verbatim (no numeric conversion).
TL_DEVICE unsigned __pack_half2(const half x, const half y) {
  // __half_as_ushort reinterprets the bits without pointer type-punning.
  const unsigned v0 = __half_as_ushort(x);
  const unsigned v1 = __half_as_ushort(y);
  return (v1 << 16) | v0;
}

// Pack two half_t values into one 32-bit word.
// The raw 16-bit storage of `x` goes in bits [15:0] and that of `y` in
// bits [31:16]; the bits are copied verbatim (no numeric conversion).
TL_DEVICE unsigned __pack_half2(const half_t x, const half_t y) {
  // half_t::raw() exposes the 16-bit storage without pointer type-punning.
  const unsigned v0 = x.raw();
  const unsigned v1 = y.raw();
  return (v1 << 16) | v0;
}

// Pack two bfloat16_t values into one 32-bit word.
// The raw 16-bit storage of `x` goes in bits [15:0] and that of `y` in
// bits [31:16]; the bits are copied verbatim (no numeric conversion).
TL_DEVICE unsigned __pack_half2(const bfloat16_t x, const bfloat16_t y) {
  // bfloat16_t::raw() exposes the 16-bit storage without pointer type-punning.
  const unsigned v0 = x.raw();
  const unsigned v1 = y.raw();
  return (v1 << 16) | v0;
}

// Helper to cast a shared-memory (SMEM) pointer to an unsigned 32-bit
// shared-window address.
// `ptr` is a generic-address-space pointer expected to point into shared
// memory. __cvta_generic_to_shared returns a size_t; it is narrowed to 32 bits.
TL_DEVICE uint32_t smem_ptr_to_uint(void const *const ptr) {
  return static_cast<uint32_t>(__cvta_generic_to_shared(ptr));
}

// Helper to cast a shared-memory (SMEM) pointer to an unsigned 32-bit
// shared-window address using inline PTX.
// `cvta.to.shared.u64` converts the generic address to a 64-bit shared-window
// address, and `cvt.u32.u64` truncates it to 32 bits.
// `smem_ptr` is expected to point into shared memory.
// NOTE(review): presumably yields the same value as smem_ptr_to_uint (which
// uses the __cvta_generic_to_shared intrinsic); consider consolidating.
TL_DEVICE unsigned int cast_smem_ptr_to_int(const void *const smem_ptr) {
  unsigned int smem_int;
  asm volatile("{ .reg .u64 smem_int; cvta.to.shared.u64 smem_int, %1; "
               "cvt.u32.u64 %0, smem_int; }"
               : "=r"(smem_int)
               : "l"(smem_ptr));
  return smem_int;
}

// AtomicAdd for FP16: atomically add `val` to the half_t at `address`.
// half_t is bit-reinterpreted as CUDA's `half` and forwarded to the native
// atomicAdd(__half *, __half) overload (no atomicCAS loop is involved).
// The native __half atomicAdd requires compute capability 7.0+.
// The previous value returned by the native atomic is discarded.
TL_DEVICE void atomicAdd(half_t *address, half_t val) {
  atomicAdd(reinterpret_cast<half *>(address), static_cast<half>(val));
}

// AtomicAdd for FP16 (pointer-operand variant): atomically add `*val` to
// the half_t at `address`.
// `val` must be a valid pointer; it is dereferenced once, non-atomically.
// Forwards to the native atomicAdd(__half *, __half) overload, which
// requires compute capability 7.0+. The previous value is discarded.
TL_DEVICE void atomicAdd(half_t *address, half_t *val) {
  atomicAdd(reinterpret_cast<half *>(address), static_cast<half>(*val));
}

// AtomicAdd for FP16 pairs: add val[0] to address[0] and val[1] to
// address[1] using a single half2 atomicAdd.
// Preconditions: `address` and `val` are both reinterpreted as half2 *, so
// each must be 4-byte aligned and point to two consecutive half_t values.
// Per the CUDA docs, atomicity is guaranteed separately for each 16-bit
// element, not for the 32-bit pair as a whole. Requires compute capability 6.0+.
TL_DEVICE void atomicAddx2(half_t *address, half_t *val) {
  // Load both addends at once; this read of `val` is not atomic.
  const half2 addend = *reinterpret_cast<const half2 *>(val);
  atomicAdd(reinterpret_cast<half2 *>(address), addend);
}

// AtomicAdd for FP16 with a float operand: atomically add `val` to the
// half_t at `address`.
// `val` is first rounded to half with __float2half, and the rounded value
// is then added. The sum is therefore not computed in fp32.
// Forwards to the native atomicAdd(__half *, __half) overload (no atomicCAS
// loop), which requires compute capability 7.0+. The previous value is
// discarded.
TL_DEVICE void atomicAdd(half_t *address, float val) {
  atomicAdd(reinterpret_cast<half *>(address), __float2half(val));
}

// DP4A: 4-way byte dot product with 32-bit accumulation, via __dp4a.
// Reads 32 bits at `a` and at `b` as packed byte lanes, reads the 32-bit
// accumulator at `c`, and stores __dp4a(a, b, c) back into `*c`.
// Requires compute capability 6.1+.
//
// NOTE(review): preconditions inferred from the 32-bit reinterpretation,
// so confirm them against callers:
//  - `a` and `b` each point to 4 contiguous 8-bit elements, 4-byte aligned
//    (the int overload of __dp4a treats the bytes as signed);
//  - `*c` holds a 32-bit integer: its bits are read as `int`, not converted,
//    so a non-integer OutDatatype would be misread.
template <typename InDatatype, typename OutDatatype>
TL_DEVICE void DP4A(InDatatype *a, InDatatype *b, OutDatatype *c) {
  // const int * keeps this well-formed when InDatatype/OutDatatype are const.
  const int a_int = *reinterpret_cast<const int *>(a);
  const int b_int = *reinterpret_cast<const int *>(b);
  const int c_int = *reinterpret_cast<const int *>(c);
  *c = __dp4a(a_int, b_int, c_int);
}