"docs/debug.rst" did not exist on "201de5f743c2e37cd7d13f642bde1075ab79a19d"
Unverified commit e9a5fa4e authored by Casper, committed by GitHub

[PyTorch] fix cross entropy vanishing gradients (#2139)



* fix cross entropy
Signed-off-by: Casper <casperbh.96@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Casper <casperbh.96@gmail.com>

* fix comments
Signed-off-by: Casper <casperbh.96@gmail.com>

* fix: few more style issues
Signed-off-by: Casper <casperbh.96@gmail.com>

* fix: remove grad_output_stride (unnecessary)
Signed-off-by: Casper <casperbh.96@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix: only backward was broken
Signed-off-by: Casper <casperbh.96@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Generalize cross entropy backward kernel to handle reduced and unreduced loss
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Casper <casperbh.96@gmail.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
parent 0f68f7b2
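
For context on the fix: cross-entropy backward has to scale the (softmax - one_hot) gradient by the upstream grad_output, and that upstream gradient is a single scalar when the loss is reduced but a per-token tensor when it is not. If the scaling is dropped or mis-broadcast, gradients flowing through the loss vanish. Below is a minimal pure-PyTorch reference sketch of the intended behaviour; it is not the Transformer Engine Triton kernel, it ignores label smoothing and ignore_index, and the name ce_backward_reference and its signature are illustrative assumptions only.

import torch
import torch.nn.functional as F

def ce_backward_reference(logits, target, grad_output, reduction="mean"):
    # d(loss_i)/d(logits_i) for plain cross entropy is softmax(logits_i) - one_hot(target_i)
    probs = F.softmax(logits.float(), dim=-1)
    one_hot = F.one_hot(target, num_classes=logits.size(-1)).to(probs.dtype)
    grad = probs - one_hot
    if reduction == "mean":
        # reduced loss: upstream grad_output is a scalar, and the mean folds in 1/N
        grad = grad * (grad_output / target.numel())
    else:
        # unreduced loss: one upstream gradient per token
        grad = grad * grad_output.unsqueeze(-1)
    return grad.to(logits.dtype)
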
@@ -6,6 +6,8 @@ import random
import torch
from transformer_engine.pytorch.cross_entropy import parallel_cross_entropy
from utils import dtype_tols
class TestParallelCrossEntropy:
@@ -18,19 +20,25 @@ class TestParallelCrossEntropy:
label_smoothing=label_smoothing, reduction="mean" if reduce_loss else "none"
)
def generate_input(self, dtype: torch.dtype, swap_dim: bool, ignore_idx: bool):
def generate_input(
self,
dtype: torch.dtype,
swap_dim: bool,
ignore_idx: bool,
device: torch.device = "cuda",
):
SQ = random.choice([64, 128])
batch = random.choice([1, 2])
vocab = random.choice([64000, 128000])
ignore = random.sample(range(0, SQ - 1), 5)
# Generate random data
if swap_dim:
self.input_test = torch.rand((SQ, batch, vocab), dtype=dtype).cuda()
self.tar_test = torch.randint(0, vocab, (SQ, batch)).cuda()
self.input_test = torch.rand((SQ, batch, vocab), dtype=dtype, device=device)
self.tar_test = torch.randint(0, vocab, (SQ, batch), device=device)
else:
self.input_test = torch.rand((batch, SQ, vocab), dtype=dtype).cuda()
self.tar_test = torch.randint(0, vocab, (batch, SQ)).cuda()
self.input_test = torch.rand((batch, SQ, vocab), dtype=dtype, device=device)
self.tar_test = torch.randint(0, vocab, (batch, SQ), device=device)
if ignore_idx:
for i in ignore:
@@ -40,9 +48,14 @@ class TestParallelCrossEntropy:
else:
self.tar_test[0][i] = -100
# Make copy of data for reference implementation
self.input_ref = torch.reshape(self.input_test.clone().detach(), (batch * SQ, vocab))
self.tar_ref = torch.reshape(self.tar_test.clone().detach(), (batch * SQ,))
# Enable autograd
self.input_test.requires_grad_()
self.input_ref.requires_grad_()
def one_iteration_test(
self,
dtype: torch.dtype,
@@ -52,18 +65,20 @@ class TestParallelCrossEntropy:
ignore_idx: bool = False,
):
# Random data
self.generate_input(dtype, swap_dim, ignore_idx)
self.input_test.requires_grad_(True)
self.input_ref.requires_grad_(True)
# Forward pass
test_loss = self.test_loss_func(
self.input_test, self.tar_test, label_smoothing, reduce_loss, None
)
ref_loss = self.ref_loss_func(self.input_ref, self.tar_ref)
# Handle backward pass based on the test scenario
# Compute square to avoid trivial backward pass
test_loss = torch.square(test_loss)
ref_loss = torch.square(ref_loss)
# Backward pass
if reduce_loss:
test_loss.backward()
ref_loss.backward()
@@ -71,16 +86,18 @@ class TestParallelCrossEntropy:
test_loss.sum().backward()
ref_loss.sum().backward()
test_loss = torch.flatten(test_loss) if not reduce_loss else test_loss
if ignore_idx:
print(test_loss, ref_loss)
# Compare gradients when backward pass was called
torch.testing.assert_close(
torch.flatten(self.input_test.grad, start_dim=0, end_dim=1), self.input_ref.grad
)
# Check that loss and grad input match
tols = dtype_tols(dtype)
test_loss = test_loss.to(dtype=torch.float64, device="cpu")
ref_loss = ref_loss.to(dtype=torch.float64, device="cpu")
ref_loss = ref_loss.reshape(test_loss.size())
test_grad_input = self.input_test.grad.to(dtype=torch.float64, device="cpu")
ref_grad_input = self.input_ref.grad.to(dtype=torch.float64, device="cpu")
ref_grad_input = ref_grad_input.reshape(test_grad_input.size())
torch.testing.assert_close(test_loss, ref_loss, **tols)
torch.testing.assert_close(test_grad_input, ref_grad_input, **tols)
# Reset data
self.input_test = None
self.input_ref = None
self.tar_test = None
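
Note on the torch.square call added in the forward pass above: with a plain backward() on a mean-reduced loss, the upstream gradient reaching the cross-entropy backward is exactly 1.0, so a kernel that drops grad_output still looks correct. Squaring the loss makes that upstream gradient 2 * loss instead, which is what lets the test catch the vanishing-gradient bug. A tiny standalone illustration (the value 0.7 is arbitrary):

import torch

loss = torch.tensor(0.7, requires_grad=True)
sq = torch.square(loss)
sq.backward()
print(loss.grad)  # tensor(1.4000) = 2 * loss, not 1.0
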
@@ -230,6 +230,7 @@ def element_mul_kernel(
X_ptr,
X_stride,
grad_output_ptr,
grad_output_stride,
n_cols,
BLOCK_SIZE: tl.constexpr,
):
@@ -252,6 +253,7 @@ def element_mul_kernel(
X_ptr += program_id * X_stride
# Load the gradient output value
grad_output_ptr += program_id * grad_output_stride
grad_output = tl.load(grad_output_ptr)
# Perform the element-wise multiplication
@@ -360,6 +362,7 @@ def cross_entropy_backward(
_input,
_input.stride(-2),
grad_output,
1 if grad_output.numel() > 1 else 0,
V,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=32,
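
The new argument, 1 if grad_output.numel() > 1 else 0, is the stride now passed to element_mul_kernel: stride 0 makes every program load the same scalar grad_output (reduced loss), while stride 1 gives each row its own upstream gradient (unreduced loss). A rough pure-PyTorch emulation of that indexing follows; scale_rows is an assumed name standing in for the kernel, and the Python loop plays the role of the Triton program grid.

import torch

def scale_rows(grad_input: torch.Tensor, grad_output: torch.Tensor) -> torch.Tensor:
    # stride 0 -> broadcast a single scalar; stride 1 -> one value per row
    stride = 1 if grad_output.numel() > 1 else 0
    flat = grad_output.reshape(-1)
    for row in range(grad_input.shape[0]):  # one "program" per row
        grad_input[row] *= flat[row * stride]
    return grad_input
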