".github/vscode:/vscode.git/clone" did not exist on "ba0bfd40e21cacfd5da6a1e43028a37258a29cb4"
Unverified commit e9a5fa4e authored by Casper, committed by GitHub
Browse files

[PyTorch] fix cross entropy vanishing gradients (#2139)



* fix cross entropy
Signed-off-by: Casper <casperbh.96@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Casper <casperbh.96@gmail.com>

* fix comments
Signed-off-by: Casper <casperbh.96@gmail.com>

* fix: few more style issues
Signed-off-by: Casper <casperbh.96@gmail.com>

* fix: remove grad_output_stride (unnecessary)
Signed-off-by: Casper <casperbh.96@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix: only backward was broken
Signed-off-by: Casper <casperbh.96@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Generalize cross entropy backward kernel to handle reduced and unreduced loss
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Casper <casperbh.96@gmail.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
parent 0f68f7b2
...@@ -6,6 +6,8 @@ import random ...@@ -6,6 +6,8 @@ import random
import torch import torch
from transformer_engine.pytorch.cross_entropy import parallel_cross_entropy from transformer_engine.pytorch.cross_entropy import parallel_cross_entropy
from utils import dtype_tols
class TestParallelCrossEntropy: class TestParallelCrossEntropy:
...@@ -18,19 +20,25 @@ class TestParallelCrossEntropy: ...@@ -18,19 +20,25 @@ class TestParallelCrossEntropy:
label_smoothing=label_smoothing, reduction="mean" if reduce_loss else "none" label_smoothing=label_smoothing, reduction="mean" if reduce_loss else "none"
) )
def generate_input(self, dtype: torch.dtype, swap_dim: bool, ignore_idx: bool): def generate_input(
self,
dtype: torch.dtype,
swap_dim: bool,
ignore_idx: bool,
device: torch.device = "cuda",
):
SQ = random.choice([64, 128]) SQ = random.choice([64, 128])
batch = random.choice([1, 2]) batch = random.choice([1, 2])
vocab = random.choice([64000, 128000]) vocab = random.choice([64000, 128000])
ignore = random.sample(range(0, SQ - 1), 5) ignore = random.sample(range(0, SQ - 1), 5)
# Generate random data
if swap_dim: if swap_dim:
self.input_test = torch.rand((SQ, batch, vocab), dtype=dtype).cuda() self.input_test = torch.rand((SQ, batch, vocab), dtype=dtype, device=device)
self.tar_test = torch.randint(0, vocab, (SQ, batch)).cuda() self.tar_test = torch.randint(0, vocab, (SQ, batch), device=device)
else: else:
self.input_test = torch.rand((batch, SQ, vocab), dtype=dtype).cuda() self.input_test = torch.rand((batch, SQ, vocab), dtype=dtype, device=device)
self.tar_test = torch.randint(0, vocab, (batch, SQ)).cuda() self.tar_test = torch.randint(0, vocab, (batch, SQ), device=device)
if ignore_idx: if ignore_idx:
for i in ignore: for i in ignore:
...@@ -40,9 +48,14 @@ class TestParallelCrossEntropy: ...@@ -40,9 +48,14 @@ class TestParallelCrossEntropy:
else: else:
self.tar_test[0][i] = -100 self.tar_test[0][i] = -100
# Make copy of data for reference implementation
self.input_ref = torch.reshape(self.input_test.clone().detach(), (batch * SQ, vocab)) self.input_ref = torch.reshape(self.input_test.clone().detach(), (batch * SQ, vocab))
self.tar_ref = torch.reshape(self.tar_test.clone().detach(), (batch * SQ,)) self.tar_ref = torch.reshape(self.tar_test.clone().detach(), (batch * SQ,))
# Enable autograd
self.input_test.requires_grad_()
self.input_ref.requires_grad_()
def one_iteration_test( def one_iteration_test(
self, self,
dtype: torch.dtype, dtype: torch.dtype,
...@@ -52,18 +65,20 @@ class TestParallelCrossEntropy: ...@@ -52,18 +65,20 @@ class TestParallelCrossEntropy:
ignore_idx: bool = False, ignore_idx: bool = False,
): ):
# Random data
self.generate_input(dtype, swap_dim, ignore_idx) self.generate_input(dtype, swap_dim, ignore_idx)
self.input_test.requires_grad_(True) # Forward pass
self.input_ref.requires_grad_(True)
test_loss = self.test_loss_func( test_loss = self.test_loss_func(
self.input_test, self.tar_test, label_smoothing, reduce_loss, None self.input_test, self.tar_test, label_smoothing, reduce_loss, None
) )
ref_loss = self.ref_loss_func(self.input_ref, self.tar_ref) ref_loss = self.ref_loss_func(self.input_ref, self.tar_ref)
# Handle backward pass based on the test scenario # Compute square to avoid trivial backward pass
test_loss = torch.square(test_loss)
ref_loss = torch.square(ref_loss)
# Backward pass
if reduce_loss: if reduce_loss:
test_loss.backward() test_loss.backward()
ref_loss.backward() ref_loss.backward()
...@@ -71,16 +86,18 @@ class TestParallelCrossEntropy: ...@@ -71,16 +86,18 @@ class TestParallelCrossEntropy:
test_loss.sum().backward() test_loss.sum().backward()
ref_loss.sum().backward() ref_loss.sum().backward()
test_loss = torch.flatten(test_loss) if not reduce_loss else test_loss # Check that loss and grad input match
tols = dtype_tols(dtype)
if ignore_idx: test_loss = test_loss.to(dtype=torch.float64, device="cpu")
print(test_loss, ref_loss) ref_loss = test_loss.to(dtype=torch.float64, device="cpu")
ref_loss = ref_loss.reshape(test_loss.size())
# Compare gradients when backward pass was called test_grad_input = self.input_test.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close( ref_grad_input = self.input_ref.grad.to(dtype=torch.float64, device="cpu")
torch.flatten(self.input_test.grad, start_dim=0, end_dim=1), self.input_ref.grad ref_grad_input = ref_grad_input.reshape(test_grad_input.size())
) torch.testing.assert_close(test_loss, ref_loss, **tols)
torch.testing.assert_close(test_grad_input, ref_grad_input, **tols)
# Reset data
self.input_test = None self.input_test = None
self.input_ref = None self.input_ref = None
self.tar_test = None self.tar_test = None
......
...@@ -230,6 +230,7 @@ def element_mul_kernel( ...@@ -230,6 +230,7 @@ def element_mul_kernel(
X_ptr, X_ptr,
X_stride, X_stride,
grad_output_ptr, grad_output_ptr,
grad_output_stride,
n_cols, n_cols,
BLOCK_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr,
): ):
...@@ -252,6 +253,7 @@ def element_mul_kernel( ...@@ -252,6 +253,7 @@ def element_mul_kernel(
X_ptr += program_id * X_stride X_ptr += program_id * X_stride
# Load the gradient output value # Load the gradient output value
grad_output_ptr += program_id * grad_output_stride
grad_output = tl.load(grad_output_ptr) grad_output = tl.load(grad_output_ptr)
# Perform the element-wise multiplication # Perform the element-wise multiplication
...@@ -360,6 +362,7 @@ def cross_entropy_backward( ...@@ -360,6 +362,7 @@ def cross_entropy_backward(
_input, _input,
_input.stride(-2), _input.stride(-2),
grad_output, grad_output,
1 if grad_output.numel() > 1 else 0,
V, V,
BLOCK_SIZE=BLOCK_SIZE, BLOCK_SIZE=BLOCK_SIZE,
num_warps=32, num_warps=32,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.