grammar.py 492 Bytes
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
from typing import List, Optional, Union

import torch


def apply_token_bitmask_inplace_cuda(
    logits: torch.Tensor,
    bitmask: torch.Tensor,
    indices: Optional[Union[List[int], torch.Tensor]] = None,
) -> None:
    """Dispatch to the ``sgl_kernel`` token-bitmask op after normalising ``indices``.

    This is a thin Python wrapper around
    ``torch.ops.sgl_kernel.apply_token_bitmask_inplace_cuda``; all actual work
    happens inside that registered op.

    Args:
        logits: Tensor handed to the kernel unchanged. Its ``device`` is the
            target device for ``indices``.
            NOTE(review): the op name suggests ``logits`` is modified in place
            on a CUDA device -- confirm against the kernel implementation.
        bitmask: Tensor handed to the kernel unchanged. It is *not* moved to
            ``logits.device`` by this wrapper.
        indices: Optional selection passed through to the kernel.
            * ``None`` is forwarded as-is.
            * A Python ``list`` is converted to an ``int32`` tensor created on
              ``logits.device``.
            * Any other non-``None`` value is expected to be a tensor and is
              moved to ``logits.device``; its dtype is left untouched.
            NOTE(review): presumably these are row indices into ``logits`` --
            verify against the kernel's contract.

    Returns:
        None. Any result is communicated through the kernel's side effects.
    """
    if indices is None:
        kernel_indices = None
    else:
        # Lists become int32 tensors; tensors are taken as given (dtype kept).
        as_tensor = (
            torch.tensor(indices, dtype=torch.int32, device=logits.device)
            if isinstance(indices, list)
            else indices
        )
        # Ensure the index tensor lives on the same device as ``logits``.
        kernel_indices = as_tensor.to(logits.device)

    torch.ops.sgl_kernel.apply_token_bitmask_inplace_cuda(
        logits, bitmask, kernel_indices
    )