rand.py 4.97 KB
Newer Older
1
2
from typing import Optional, Union

3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import torch
import triton
import triton.language as tl


def seeded_uniform(
    *size,
    seeds: torch.Tensor,
    out: Optional[torch.Tensor] = None,
    dtype: Optional[torch.dtype] = None,
    device: Optional[Union[torch.device, str]] = None,
    pin_memory: Optional[bool] = False,
) -> torch.Tensor:
    """Similar to torch.rand, but allows for seeds to be set per row.

    seeds must be a 1d tensor. The output tensor may be 1d, 2d, or 3d.
    If it is 3d, the additional seeds needed will be derived automatically
    in a deterministic fashion:
    [
        row 0: [columns_with_seed_0], [columns_with_seed0^1], ...
    ]

    Args:
        *size: Shape of the output, given either as separate ints
            (``seeded_uniform(2, 3, ...)``) or as a single tuple/list
            (``seeded_uniform((2, 3), ...)``). Between 1 and 3 dimensions.
        seeds: 1D tensor holding one seed per output row. A 1D output
            counts as a single row, so it needs exactly one seed.
        out: Optional preallocated tensor to fill. Its shape must equal
            ``size`` and its last dimension must have stride 1.
        dtype: dtype used when allocating the output (ignored if ``out``
            is given).
        device: Device used when allocating the output (ignored if ``out``
            is given).
        pin_memory: Forwarded to ``torch.empty`` when allocating the output.

    Returns:
        The filled tensor (``out`` itself when it was supplied). An output
        with zero elements is returned as-is without launching the kernel.

    Raises:
        ValueError: If ``size`` has 0 or more than 3 dimensions, if ``out``
            does not match ``size`` or is not contiguous along its last
            dimension, if ``seeds`` is not 1D, or if ``seeds`` does not
            have one element per output row.
    """
    # Mirror torch.rand: the shape may arrive as a single sequence. Without
    # unpacking it, a (2, 3) request would be treated as 1D and only the
    # first few elements of the output would ever be written.
    if len(size) == 1 and isinstance(size[0], (tuple, list)):
        size = tuple(size[0])
    n_dims = len(size)

    if n_dims == 0:
        raise ValueError("seeded_uniform requires at least one dimension")
    if n_dims > 3:
        raise ValueError("seeded_uniform only supports up to 3D tensors")

    if out is None:
        out = torch.empty(*size,
                          dtype=dtype,
                          device=device,
                          pin_memory=pin_memory)
    elif out.shape != size:
        raise ValueError("shape of out and size must be the same")

    if n_dims == 3:
        n_rows, n_3d, n_cols = out.shape
        stride_row = out.stride(0)
        stride_3d = out.stride(1)
    elif n_dims == 2:
        n_rows, n_cols = out.shape
        n_3d = 1
        stride_row = out.stride(0)
        stride_3d = 1
    else:
        n_cols = out.shape[0]
        n_rows = 1
        n_3d = 1
        stride_row = 1
        stride_3d = 1

    # The kernel is given no column stride and addresses consecutive
    # elements within a row, so a strided last dimension (e.g. a transposed
    # view passed as ``out``) would be written at the wrong locations.
    if n_cols > 1 and out.stride(-1) != 1:
        raise ValueError("out must be contiguous along its last dimension")

    if seeds.ndim != 1:
        raise ValueError("seeds must be a 1D tensor")

    if seeds.numel() != n_rows:
        raise ValueError(
            "seeds must have the same number of elements as out has rows")

    # Nothing to generate. Launching anyway would either use an empty grid
    # (n_rows == 0 or n_3d == 0) or compute n_slices == 0 (n_cols == 0),
    # which trips the kernel's static assertion.
    if out.numel() == 0:
        return out

    # The philox PRNG Triton uses generates 4 random numbers at once.
    # Therefore, the most efficient use of it is to divide the
    # block size by 4, and then save the generated random numbers to
    # each of the 4 slices of the tensor.
    full_block_size = triton.next_power_of_2(n_cols)
    philox_block_size = max(full_block_size // 4, 1)
    n_slices = full_block_size // philox_block_size
    num_warps = 4
    # Manual tuning. This seems to give best performance on A100 for
    # simple kernels like this.
    if philox_block_size >= 8192:
        num_warps = 32
    elif philox_block_size >= 4096:
        num_warps = 16
    elif philox_block_size >= 2048:
        num_warps = 8

    # One program per (row, second-dimension index) pair.
    _seeded_uniform_triton[(n_rows, n_3d)](
        out,
        seeds,
        stride_row,
        stride_3d,
        seeds.stride(0),
        n_rows,
        n_3d,
        n_cols,
        n_slices=n_slices,
        num_warps=num_warps,
        block_size=philox_block_size,
    )
    return out


@triton.jit
def _seeded_uniform_triton(
    out_ptr: torch.Tensor,
    seed_ptr: torch.Tensor,
    out_row_stride: int,
    out_3d_stride: int,
    seed_row_stride: int,
    n_rows: int,
    n_3d: int,
    n_cols: int,
    n_slices: tl.constexpr,
    block_size: tl.constexpr,
):
    """
    Generate a random float32 number in [0, 1) for each element in the output
    tensor. The random numbers in a row are generated using the seed for that
    row.

    Each program instance handles one (row, second-dimension index) pair:
    program id axis 0 selects the row and program id axis 1 selects the
    index along the second dimension. For a second-dimension index greater
    than zero, the row's seed is XORed with that index before use.

    Columns are addressed as consecutive elements starting at the row's base
    pointer; no column stride is passed in, so the last dimension of the
    output is assumed to have stride 1.

    Args:
        out_ptr: The output tensor.
        seed_ptr: The per-row seeds to use for random number generation.
        out_row_stride: The stride between rows of the output tensor.
        out_3d_stride: The stride between 3D slices of the output tensor.
        seed_row_stride: The stride between rows of the seed tensor.
        n_rows: The number of rows in the output tensor. Not read by the
            kernel body.
        n_3d: The size of second dimension of the output tensor,
            if output tensor is 3D. Not read by the kernel body.
        n_cols: The number of columns in the output tensor. Used as the
            upper bound in every store mask.
        n_slices: The number of philox outputs to use (between 1 and 4).
        block_size: The number of offsets passed to ``tl.rand4x``; each
            philox output that is used fills ``block_size`` consecutive
            columns.
    """
    tl.static_assert(n_slices > 0 and n_slices <= 4, "0 < n_slices <= 4")

    # Get the row index and the index along the second (3D) dimension.
    row_idx = tl.program_id(axis=0)
    three_d_idx = tl.program_id(axis=1)

    philox_offsets = tl.arange(0, block_size)
    # Get the seed for the current row.
    seed = tl.load(seed_ptr + row_idx * seed_row_stride)
    # Derive a distinct seed for every 3D slice after the first one.
    if three_d_idx > 0:
        seed ^= three_d_idx
    # Generate four blocks of random numbers in [0, 1).
    out1, out2, out3, out4 = tl.rand4x(seed, philox_offsets)

    output_row_start_ptr = (out_ptr + row_idx * out_row_stride +
                            three_d_idx * out_3d_stride)
    # Slice k of the row covers columns [k * block_size, (k + 1) * block_size).
    # The masks drop the columns at or beyond n_cols.
    out1_offsets = philox_offsets
    tl.store(output_row_start_ptr + out1_offsets,
             out1,
             mask=out1_offsets < n_cols)
    # n_slices is a constexpr; the remaining philox outputs are only stored
    # when the corresponding slice is requested.
    if n_slices > 1:
        out2_offsets = tl.arange(block_size, block_size * 2)
        tl.store(output_row_start_ptr + out2_offsets,
                 out2,
                 mask=out2_offsets < n_cols)
    if n_slices > 2:
        out3_offsets = tl.arange(block_size * 2, block_size * 3)
        tl.store(output_row_start_ptr + out3_offsets,
                 out3,
                 mask=out3_offsets < n_cols)
    if n_slices > 3:
        out4_offsets = tl.arange(block_size * 3, block_size * 4)
        tl.store(output_row_start_ptr + out4_offsets,
                 out4,
                 mask=out4_offsets < n_cols)