// common.h
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#pragma once

#include <cuda_runtime.h>
#include <cutlass/fast_math.h>
#include <cutlass/numeric_types.h>
#include <math_constants.h>

// Bring the CUTLASS reduced-precision numeric types into the enclosing scope
// so they can be named without the cutlass:: prefix.
using cutlass::bfloat16_t;
using cutlass::half_t;
using cutlass::tfloat32_t;

// Short math aliases. Each h* name expands to the cutlass::fast_* routine for
// the same operation, except hpow, which expands to powf.
#define hexp cutlass::fast_exp
#define hlog cutlass::fast_log
#define hsqrt cutlass::fast_sqrt
#define htanh cutlass::fast_tanh
#define hpow powf

// Shorthand spellings for unsigned integer types. These are preprocessor
// macros, not typedefs, so they textually replace the identifiers wherever
// they appear after this point.
#define uint unsigned int
#define uchar unsigned char
#define ushort unsigned short

// Qualifier used on the device helpers below: device-only and force-inlined.
#define TL_DEVICE __forceinline__ __device__

// Pack two CUDA `half` values into one 32-bit word.
//
// The 16-bit storage of `x` lands in bits [15:0] of the result and the 16-bit
// storage of `y` lands in bits [31:16]. The values are reinterpreted bit for
// bit; no numeric conversion takes place.
TL_DEVICE unsigned __pack_half2(const half x, const half y) {
  // Read each operand's storage as an unsigned 16-bit pattern (const-correct
  // reinterpretation of the by-value parameters), widened to 32 bits.
  const unsigned lo = *reinterpret_cast<const unsigned short *>(&x);
  const unsigned hi = *reinterpret_cast<const unsigned short *>(&y);
  return (hi << 16) | lo;
}

// Pack two cutlass `half_t` values into one 32-bit word.
//
// The first 16 bits of `x`'s storage land in bits [15:0] of the result and the
// first 16 bits of `y`'s storage land in bits [31:16]. The values are
// reinterpreted bit for bit; no numeric conversion takes place.
TL_DEVICE unsigned __pack_half2(const half_t x, const half_t y) {
  // Read each operand's storage as an unsigned 16-bit pattern (const-correct
  // reinterpretation of the by-value parameters), widened to 32 bits.
  const unsigned lo = *reinterpret_cast<const unsigned short *>(&x);
  const unsigned hi = *reinterpret_cast<const unsigned short *>(&y);
  return (hi << 16) | lo;
}

// Pack two cutlass `bfloat16_t` values into one 32-bit word.
//
// The first 16 bits of `x`'s storage land in bits [15:0] of the result and the
// first 16 bits of `y`'s storage land in bits [31:16]. The values are
// reinterpreted bit for bit; no numeric conversion takes place.
TL_DEVICE unsigned __pack_half2(const bfloat16_t x, const bfloat16_t y) {
  // Read each operand's storage as an unsigned 16-bit pattern (const-correct
  // reinterpretation of the by-value parameters), widened to 32 bits.
  const unsigned lo = *reinterpret_cast<const unsigned short *>(&x);
  const unsigned hi = *reinterpret_cast<const unsigned short *>(&y);
  return (hi << 16) | lo;
}

/// Helper to turn a shared-memory pointer into a 32-bit unsigned value.
///
/// Passes `ptr` through `__cvta_generic_to_shared` and narrows the result to
/// `uint32_t`.
/// NOTE(review): presumably `ptr` must actually refer to shared memory for the
/// returned value to be meaningful -- confirm against callers.
TL_DEVICE uint32_t smem_ptr_to_uint(void const *const ptr) {
  const auto shared_addr = __cvta_generic_to_shared(ptr);
  return static_cast<uint32_t>(shared_addr);
}

// Atomic add for FP16: atomically adds `val` to the half_t at `address`.
//
// Forwards to CUDA's built-in `atomicAdd(half *, half)` overload by
// reinterpreting the destination as `half *` and converting `val` to `half`.
// The value returned by the built-in is discarded.
// Requires SM70+: the CUDA Programming Guide lists the `half` overload of
// atomicAdd for compute capability 7.x and higher.
TL_DEVICE void atomicAdd(half_t *address, half_t val) {
  half *const dst = reinterpret_cast<half *>(address);
  atomicAdd(dst, static_cast<half>(val));
}

// Atomic add for FP16, addend passed by pointer: atomically adds `*val` to the
// half_t at `address`.
//
// `val` is dereferenced once, converted to `half`, and forwarded to CUDA's
// built-in `atomicAdd(half *, half)` overload. The value returned by the
// built-in is discarded.
// Requires SM70+: the CUDA Programming Guide lists the `half` overload of
// atomicAdd for compute capability 7.x and higher.
TL_DEVICE void atomicAdd(half_t *address, half_t *val) {
  half *const dst = reinterpret_cast<half *>(address);
  atomicAdd(dst, static_cast<half>(*val));
}

// Paired atomic add for FP16: treats `address` and `val` as pointers to
// `half2` and forwards to CUDA's built-in `atomicAdd(half2 *, half2)`.
//
// Both pointers are reinterpreted as `half2 *`, so each must refer to two
// consecutive half_t elements.
// NOTE(review): the reinterpretation assumes both pointers satisfy half2's
// alignment -- confirm at call sites.
// The value returned by the built-in is discarded.
// Requires SM60+: the CUDA Programming Guide lists the `half2` overload of
// atomicAdd for compute capability 6.x and higher.
TL_DEVICE void atomicAddx2(half_t *address, half_t *val) {
  half2 *const dst = reinterpret_cast<half2 *>(address);
  // Load the two-element addend by value; it is already a half2, so no further
  // cast is needed.
  const half2 addend = *reinterpret_cast<half2 *>(val);
  atomicAdd(dst, addend);
}

// Atomic add for FP16 with a float addend: converts `val` to `half` with
// `__float2half` and atomically adds it to the half_t at `address`.
//
// Forwards to CUDA's built-in `atomicAdd(half *, half)` overload; the value
// returned by the built-in is discarded.
// Requires SM70+: the CUDA Programming Guide lists the `half` overload of
// atomicAdd for compute capability 7.x and higher.
TL_DEVICE void atomicAdd(half_t *address, float val) {
  half *const dst = reinterpret_cast<half *>(address);
  atomicAdd(dst, __float2half(val));
}

// DP4A: dot-product-accumulate through the `__dp4a` intrinsic.
//
// Reads one 32-bit word from each of `a`, `b` and `c` (the pointers are
// reinterpreted as `int *`), calls `__dp4a(a_word, b_word, c_word)` and stores
// the result to `*c`. All three arguments are `int`, so the `int` overload of
// `__dp4a` is the one selected.
// NOTE(review): presumably `a` and `b` each point at four packed 8-bit values
// and `c` at a 32-bit accumulator -- confirm against callers.
// Requires SM61+ per the CUDA documentation for `__dp4a`.
template <typename InDatatype, typename OutDatatype>
TL_DEVICE void DP4A(InDatatype *a, InDatatype *b, OutDatatype *c) {
  const int lhs_word = *((int *)a);
  const int rhs_word = *((int *)b);
  const int acc_word = *((int *)c);
  *c = __dp4a(lhs_word, rhs_word, acc_word);
}