Unverified Commit c6c1f50e authored by Tim Moon, committed by GitHub

[PyTorch] Add ops for dropout and constant scale (#1995)



* Add ops for dropout and constant scale
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 38c26dd8
@@ -36,9 +36,7 @@ from transformer_engine.pytorch.utils import is_bf16_compatible
import transformer_engine_torch as tex
# Import utility functions
_current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent))
-from utils import dtype_tols, make_recipe
+from utils import dtype_tols, make_recipe, reset_rng_states
# Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
@@ -327,10 +325,7 @@ class TestFuser:
    @staticmethod
    def setup_class(cls) -> None:
        # Configure RNG
-        seed = 1234
-        torch.manual_seed(seed)
-        torch.cuda.manual_seed(seed)
+        reset_rng_states()
    @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
    def test_fp8_scale_update(
@@ -544,10 +539,7 @@ class TestBasicOps:
    @staticmethod
    def setup_class(cls) -> None:
        # Configure RNG
-        seed = 1234
-        torch.manual_seed(seed)
-        torch.cuda.manual_seed(seed)
+        reset_rng_states()
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("device", ("cuda", "cpu"))
@@ -1693,16 +1685,107 @@ class TestBasicOps:
        torch.testing.assert_close(y_test, y_ref, **tols)
        torch.testing.assert_close(dx_test, x_ref.grad, **tols)
@pytest.mark.parametrize("scale", (1, 0, -2.5, 3.5))
@pytest.mark.parametrize("shape", ((), (1, 13), (4, 4, 2)))
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("device", _devices)
def test_constant_scale(
self,
*,
scale: float,
shape: Iterable[int],
dtype: torch.dtype,
device: torch.device,
):
# Random data
x_ref, x_test = make_reference_and_test_tensors(
shape,
test_dtype=dtype,
test_device=device,
)
dy_ref, dy_test = make_reference_and_test_tensors(
shape,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
# Plain PyTorch implementation
y_ref = scale * x_ref
y_ref.backward(dy_ref)
# Implementation with fusible operation
op = te_ops.ConstantScale(scale)
y_test = op(x_test)
y_test.backward(dy_test)
# Check results
tols = dtype_tols(dtype)
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
@pytest.mark.parametrize("prob", (0.1, 0.5, 0.75))
@pytest.mark.parametrize("is_training", (True, False))
@pytest.mark.parametrize("shape", ((101,), (2, 4, 16)))
@pytest.mark.parametrize("dtype", _dtypes)
def test_dropout(
self,
*,
prob: float,
is_training: bool,
shape: Iterable[int],
dtype: torch.dtype,
device: torch.device = "cuda",
):
# Random data
x_ref = torch.rand(shape, dtype=dtype, device=device) + 0.5
x_test = x_ref.clone().requires_grad_()
dy_ref = torch.rand(shape, dtype=dtype, device=device) + 0.5
dy_test = dy_ref.clone()
# Apply dropout
op = te_ops.Dropout(prob)
if is_training:
op.train()
else:
op.eval()
y = op(x_test)
y.backward(dy_test)
# Check values
if is_training:
mask = ((y != 0) / (1 - prob)).to(dtype=dtype)
torch.testing.assert_close(y, x_ref * mask)
torch.testing.assert_close(x_test.grad, dy_ref * mask)
else:
torch.testing.assert_close(y, x_ref, rtol=0, atol=0)
torch.testing.assert_close(x_test.grad, dy_ref, rtol=0, atol=0)
# Hypothesis testing for number of zeros
# Note: A Bernoulli random variable with probability p has
# mean p and standard deviation sqrt(p*(1-p)). By the central
# limit theorem, the mean of n iid Bernoulli variables
# converges to a normal random variable with mean p and
# standard deviation sqrt(p*(1-p)/n). If the observed mean is
# below the 0.5th or above the 99.5th percentiles, then the
# p-value is less than 1% and we assume that the dropout
# distribution is incorrect.
if is_training:
prob_observed = 1 - torch.count_nonzero(y).item() / y.numel()
z_score = (prob_observed - prob) / math.sqrt(prob * (1 - prob) / y.numel())
assert abs(z_score) < 2.5758, "Number of zeros is outside 99% confidence interval"
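
As a side note, the 2.5758 cutoff used in the assertion is the two-sided 99% critical value of the standard normal distribution. A minimal sketch, not part of this commit, recovers it with only the Python standard library:

from statistics import NormalDist

# Two-sided test at the 1% significance level: reject if |z| exceeds
# the 99.5th percentile of the standard normal distribution.
cutoff = NormalDist().inv_cdf(0.995)
print(round(cutoff, 4))  # 2.5758
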
class TestFusedOps:
    """Tests for fused operations"""

    @staticmethod
    def setup_class(cls) -> None:
        # Configure RNG
-        seed = 1234
-        torch.manual_seed(seed)
-        torch.cuda.manual_seed(seed)
+        reset_rng_states()
@pytest.mark.parametrize("weight_shape", ((32, 64), (3, 5)))
@pytest.mark.parametrize("in_shape", ((-1,), (1, 7, -1), (8, 2, 10, -1)))
@@ -2125,10 +2208,7 @@ class TestCheckpointing:
    @staticmethod
    def setup_class(cls) -> None:
        # Configure RNG
-        seed = 1234
-        torch.manual_seed(seed)
-        torch.cuda.manual_seed(seed)
+        reset_rng_states()
@pytest.mark.parametrize("quantization", _quantization_list)
@pytest.mark.parametrize("quantized_weight", (False, True))
@@ -2240,10 +2320,7 @@ class TestSequentialModules:
    @staticmethod
    def setup_class(cls) -> None:
        # Configure RNG
-        seed = 1234
-        torch.manual_seed(seed)
-        torch.cuda.manual_seed(seed)
+        reset_rng_states()
@pytest.mark.parametrize("requires_grad", (False, True))
@pytest.mark.parametrize("bias", (False, True))
@@ -10,6 +10,8 @@ from .all_gather import AllGather
from .all_reduce import AllReduce
from .basic_linear import BasicLinear
from .bias import Bias
+from .constant_scale import ConstantScale
+from .dropout import Dropout
from .identity import Identity
from .l2normalization import L2Normalization
from .layer_norm import LayerNorm
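
With these exports in place, the new ops compose with the rest of the fusible-ops API. A hypothetical usage sketch (the te_ops.Sequential container is assumed to be available, as elsewhere in this API):

import torch
import transformer_engine.pytorch.ops as te_ops

# Scale activations by a constant, then apply dropout
model = te_ops.Sequential(
    te_ops.ConstantScale(0.5),
    te_ops.Dropout(0.1),
)
x = torch.randn(16, 32, device="cuda", requires_grad=True)
y = model(x)
y.sum().backward()
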
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Fusible operation for constant scaling."""

from __future__ import annotations
from typing import Optional

import torch

from transformer_engine.pytorch.ops.op import (
    BasicOperation,
    OperationContext,
)
from ...tensor import Quantizer


class ConstantScale(BasicOperation):
    """Multiply by a constant"""

    def __init__(self, scale: float) -> None:
        super().__init__()
        self.scale = scale

    def op_forward(
        self,
        ctx: OperationContext,
        input_: torch.Tensor,
        prev_op_grad_output_quantizer: Optional[Quantizer],
        next_op_input_quantizer: Optional[Quantizer],
    ) -> torch.Tensor:
        return input_ * self.scale

    def op_backward(
        self,
        ctx: OperationContext,
        grad_output: torch.Tensor,
    ) -> tuple[torch.Tensor, tuple[()]]:
        # d(s * x)/dx = s, so the gradient is scaled by the same constant
        return grad_output * self.scale, ()
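
Since d(s*x)/dx = s, op_backward simply rescales the incoming gradient. A minimal plain-PyTorch sketch of the same forward/backward math, independent of the fusible-ops machinery:

import torch

scale = -2.5
x = torch.randn(4, requires_grad=True)
y = scale * x  # forward: y = s * x
y.backward(torch.ones_like(y))  # backward: dL/dx = s * dL/dy
assert torch.equal(x.grad, torch.full_like(x, scale))
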
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Fusible operation for dropout."""

from __future__ import annotations
from typing import Optional

import torch

from transformer_engine.pytorch.ops.op import (
    BasicOperation,
    OperationContext,
)
from ...tensor import Quantizer


class Dropout(BasicOperation):
    """Randomly zero out tensor entries during training

    During training, tensor entries are randomly set to zero with
    probability :math:`p` and the remaining entries are scaled by
    :math:`1/(1-p)`.

    """

    def __init__(self, p: float) -> None:
        super().__init__()
        self.dropout_probability = p

    def op_forward(
        self,
        ctx: OperationContext,
        input_: torch.Tensor,
        prev_op_grad_output_quantizer: Optional[Quantizer],
        next_op_input_quantizer: Optional[Quantizer],
    ) -> torch.Tensor:

        # Compute dropout if training
        out = input_
        is_training = self.training
        mask = None
        if is_training:
            keep_prob = 1 - self.dropout_probability
            # Scaled mask: entries are 1/keep_prob with probability
            # keep_prob, and 0 otherwise
            mask = torch.empty_like(input_)
            mask.bernoulli_(keep_prob)
            mask *= 1 / keep_prob
            out = out * mask

        # Save context for backward
        if ctx.requires_grad:
            ctx.save_for_backward(mask)
            ctx.is_training = is_training

        return out

    def op_backward(
        self,
        ctx: OperationContext,
        grad_output: torch.Tensor,
    ) -> tuple[torch.Tensor, tuple[()]]:
        (mask,) = ctx.saved_tensors
        grad_input = grad_output
        if ctx.is_training:
            # Apply the same scaled mask as in the forward pass
            grad_input = grad_input * mask
        return grad_input, ()
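
The forward pass above is the standard inverted-dropout formulation: the scaled mask has expected value 1, so the output is unbiased in expectation. A minimal plain-PyTorch sketch of the same train-mode computation:

import torch

def dropout_reference(x: torch.Tensor, p: float, training: bool) -> torch.Tensor:
    if not training:
        return x  # inference mode is the identity
    keep_prob = 1 - p
    # Entries are 1/keep_prob with probability keep_prob, else 0,
    # so E[mask] == 1 and E[x * mask] == x
    mask = torch.empty_like(x).bernoulli_(keep_prob) / keep_prob
    return x * mask
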