curanddx.hpp 3.06 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_COMMON_UTIL_CURANDDX_HPP_
#define TRANSFORMER_ENGINE_COMMON_UTIL_CURANDDX_HPP_

namespace transformer_engine {
namespace curanddx {
namespace detail {

// Key-schedule increments: multiple_rounds() adds these to key.x / key.y between rounds.
inline constexpr unsigned int philox4x32_w32_0{0x9E3779B9U};
inline constexpr unsigned int philox4x32_w32_1{0xBB67AE85U};
// Round multipliers: single_round() multiplies ctr.x / ctr.z by these.
inline constexpr unsigned int philox4x32_m4x32_0{0xD2511F53U};
inline constexpr unsigned int philox4x32_m4x32_1{0xCD9E8D57U};

// Full 32x32 -> 64-bit unsigned multiply, split into two 32-bit halves.
//
// a, b : the two factors.
// hip  : output; receives the upper 32 bits of a * b. Must be a valid pointer.
// Returns the lower 32 bits of a * b.
__forceinline__ __device__ unsigned int mulhilo32(unsigned int a, unsigned int b,
                                                  unsigned int* hip) {
  // Unsigned multiplication wraps modulo 2^32, which is exactly the low word.
  const unsigned int low_word = a * b;
  // __umulhi yields the high word of the same product.
  *hip = __umulhi(a, b);
  return low_word;
}

// Applies one Philox4x32 round to the counter `ctr` using the round key `key`.
//
// ctr.x and ctr.z are each multiplied by a round multiplier; the high halves of
// those products are XOR-ed with ctr.y / ctr.w and the key words, while the low
// halves are passed through to the other output lanes.
// Returns the transformed 4x32-bit counter; the inputs are taken by value.
__forceinline__ __device__ uint4 single_round(uint4 ctr, uint2 key) {
  unsigned int high_x;
  unsigned int high_z;
  const unsigned int low_x = mulhilo32(philox4x32_m4x32_0, ctr.x, &high_x);
  const unsigned int low_z = mulhilo32(philox4x32_m4x32_1, ctr.z, &high_z);

  uint4 mixed;
  mixed.x = high_z ^ ctr.y ^ key.x;
  mixed.y = low_z;
  mixed.z = high_x ^ ctr.w ^ key.y;
  mixed.w = low_x;
  return mixed;
}

// Applies `Rounds` consecutive Philox4x32 rounds to counter `c` under key `k`.
//
// Rounds : compile-time number of rounds; must be at least 1.
// c      : 4x32-bit counter block to scramble (taken by value).
// k      : 2x32-bit key (taken by value); bumped by the key-schedule constants
//          between rounds.
// Returns the counter block after all rounds.
template <unsigned int Rounds>
__forceinline__ __device__ uint4 multiple_rounds(uint4 c, uint2 k) {
  // `Rounds - 1` is unsigned: with Rounds == 0 it would wrap to 0xFFFFFFFF and
  // the loop below would execute ~2^32 rounds instead of failing loudly.
  static_assert(Rounds > 0, "multiple_rounds requires at least one Philox round");

  // The first Rounds - 1 rounds are each followed by a key bump ...
  for (unsigned int i = 0; i < Rounds - 1; i++) {
    c = single_round(c, k);
    k.x += philox4x32_w32_0;
    k.y += philox4x32_w32_1;
  }
  // ... and the final round uses the last key without bumping it again.
  return single_round(c, k);
}

// State of a counter-based Philox4x32 generator that runs `Rounds` rounds per
// generated block.
//
// The state is a 128-bit counter stored as four 32-bit words (ctr.x is the least
// significant word, ctr.w the most significant) plus a 64-bit key. Each call to
// generate4() scrambles the current counter into 128 bits of output and then
// advances the counter by one.
template <unsigned int Rounds>
struct philox4x32_native_state {
  static constexpr unsigned int rounds = Rounds;

  uint4 ctr;  // 128-bit counter: x = lowest 32 bits ... w = highest 32 bits
  uint2 key;  // 64-bit key: x = low 32 bits of the seed, y = high 32 bits

  // Adds 1 to the 128-bit counter, rippling the carry from x up to w.
  // Overflow out of ctr.w wraps around silently.
  __forceinline__ __device__ void philox_state_incr() {
    if (++ctr.x) return;
    if (++ctr.y) return;
    if (++ctr.z) return;
    ++ctr.w;
  }

  // Adds the 64-bit value `n` to the 128-bit counter.
  //
  // The two carries (out of ctr.x, and out of ctr.y) are tracked separately.
  // Folding the first carry into the high word of `n` before adding it would
  // wrap to 0 when that word is 0xFFFFFFFF, and the carry into ctr.z would be
  // lost.
  // NOTE(review): `n >> 32` assumes size_t is 64 bits wide - confirm for any
  // target where that might not hold.
  __forceinline__ __device__ void philox_state_incr(size_t n) {
    const unsigned int nlo = (unsigned int)(n);
    const unsigned int nhi = (unsigned int)(n >> 32);

    // Low word: unsigned wrap-around means the sum is smaller than the addend
    // exactly when a carry was produced.
    ctr.x += nlo;
    const unsigned int carry_lo = (ctr.x < nlo) ? 1U : 0U;

    // Second word: add the high half of n, then the carry from the low word.
    // At most one of the two additions can overflow.
    ctr.y += nhi;
    bool carry_hi = ctr.y < nhi;
    ctr.y += carry_lo;
    carry_hi = carry_hi || (ctr.y < carry_lo);

    if (!carry_hi) return;
    if (++ctr.z) return;
    ++ctr.w;
  }

  // Adds the 64-bit value `n` to the upper 64 bits of the counter (ctr.z, ctr.w),
  // leaving ctr.x and ctr.y untouched. Overflow out of ctr.w wraps around.
  __forceinline__ __device__ void philox_state_incr_hi(size_t n) {
    unsigned int nlo = (unsigned int)(n);
    unsigned int nhi = (unsigned int)(n >> 32);

    ctr.z += nlo;
    if (ctr.z < nlo) nhi++;  // carry from z into w; a wrap of nhi here only drops
                             // a carry out of the top word, which is intended

    ctr.w += nhi;
  }

  // Advances the counter by `n` blocks. One block is the 128 bits returned by a
  // single generate4() call, so this is equivalent to discarding n such calls.
  __forceinline__ __device__ void skip_offset(size_t n) { philox_state_incr(n); }

  // Advances the upper 64 bits of the counter by `n`, i.e. jumps ahead by
  // n * 2^64 blocks.
  __forceinline__ __device__ void skip_subsequence(size_t n) { philox_state_incr_hi(n); }

  // Resets the state: zero counter, key taken from the 64-bit `seed`, then the
  // counter is positioned at (`subsequence`, `offset`) via the skip functions.
  __forceinline__ __device__ void init(size_t seed, size_t subsequence, size_t offset) {
    ctr = make_uint4(0, 0, 0, 0);
    key.x = (unsigned int)seed;
    key.y = (unsigned int)(seed >> 32);

    skip_subsequence(subsequence);
    skip_offset(offset);
  }

  // Produces four 32-bit values from the current counter and key, then advances
  // the counter by one so the next call yields a fresh block.
  __forceinline__ __device__ uint4 generate4() {
    auto tmp = multiple_rounds<Rounds>(ctr, key);
    philox_state_incr();
    return tmp;
  }
};
}  // namespace detail
}  // namespace curanddx
}  // namespace transformer_engine

#endif  // TRANSFORMER_ENGINE_COMMON_UTIL_CURANDDX_HPP_