pad.py 1023 Bytes
Newer Older
1
2
3
4
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

Teddy Do's avatar
Teddy Do committed
5
"""PyTorch wrapper functions for padding Triton kernels."""
6
7
8
9

import torch
import triton

Teddy Do's avatar
Teddy Do committed
10
from transformer_engine.common.triton.pad import zero_pad_kernel
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43


def pad_columnwise_scale_inv(inp: torch.Tensor) -> torch.Tensor:
    """Pads a tensor assuming it's a columnwise scaling inverse."""

    assert inp.ndim == 2
    dim0, dim1 = inp.shape

    pad_x = (128 - dim0 % 128) % 128
    pad_y = (4 - dim1 % 4) % 4
    out_x = dim0 + pad_x
    out_y = dim1 + pad_y
    out = torch.empty((out_x, out_y), device=inp.device, dtype=inp.dtype)

    in_s0, in_s1 = inp.stride()
    out_s0, out_s1 = out.stride()

    BLOCK_M, BLOCK_N = 128, 128
    grid = (triton.cdiv(out_x, BLOCK_M), triton.cdiv(out_y, BLOCK_N))

    zero_pad_kernel[grid](
        inp,
        out,
        dim0,
        dim1,
        out_x,
        out_y,
        in_s0,
        in_s1,
        out_s0,
        out_s1,
    )
    return out