sample.py 423 Bytes
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
import math

# This is a hardcoded limit in Triton (max block size): 131072 == 2**17.
MAX_TRITON_N_COLS = 131072


def get_num_triton_sampler_splits(n_cols: int) -> int:
    """Get the number of splits to use for Triton sampling.

    Triton has a limit on the number of columns it can handle, so we need to
    split the tensor and call the kernel multiple times if it's too large.

    Args:
        n_cols: Number of columns in the tensor to be sampled.

    Returns:
        ``n_cols`` divided by ``MAX_TRITON_N_COLS``, rounded up to the next
        whole number (so ``0`` columns yields ``0`` splits).
    """
    # Ceiling division done entirely in integer arithmetic: negating, floor
    # dividing, and negating again rounds towards +infinity. This avoids the
    # float true division used by ``math.ceil(a / b)``, which loses precision
    # once the operands no longer fit in a double's 53-bit mantissa.
    # ``int()`` keeps the return type a plain ``int`` for any numeric input.
    return int(-(-n_cols // MAX_TRITON_N_COLS))