custom.py 347 Bytes
Newer Older
maming's avatar
maming committed
1
2
3
4
5
6
7
8
9
10
11
12
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

import torch

from megatron.energon.transforms.mappers import TransformMapper


class CustomTransform(TransformMapper[torch.nn.Module]):
    """Base class for user-supplied custom transforms.

    Adds no behavior of its own: it only specializes ``TransformMapper`` with
    ``torch.nn.Module`` as the type parameter, giving custom transforms a
    common parent. Subclasses are expected to override at least
    ``apply_transform``.
    """