Unverified Commit 8ce49c01 authored by Li Tao, committed by GitHub

[Pytorch] Bugfix in te fusion ce implementation (#1879)



* Fix an issue when mcore uses te fusion ce implementation
Signed-off-by: lit <lit@nvidia.com>

* simplify unit test code
Signed-off-by: lit <lit@nvidia.com>

* Update tests/pytorch/test_parallel_cross_entropy.py
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

---------
Signed-off-by: lit <lit@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent 01a504c4
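
Context for the fix: the diff below shows the fused kernel previously divided the stored gradients by n_non_ignore unconditionally, which is the correct scale only for a mean-reduced loss. The scaling convention being restored is standard PyTorch behavior; a minimal plain-PyTorch sketch (not TE code) of it:

import torch
import torch.nn.functional as F

# Sketch only: reduction-dependent gradient scaling of cross entropy.
logits = torch.randn(4, 8, requires_grad=True)
targets = torch.randint(0, 8, (4,))

# reduction="mean": each row's gradient carries an implicit 1/n_rows factor.
F.cross_entropy(logits, targets, reduction="mean").backward()
grad_mean = logits.grad.clone()
logits.grad = None

# reduction="none" + sum(): per-row gradients, no 1/n_rows factor.
F.cross_entropy(logits, targets, reduction="none").sum().backward()
grad_unreduced = logits.grad.clone()

# Unreduced gradients are exactly n_rows times the mean-reduced ones.
torch.testing.assert_close(grad_unreduced, grad_mean * logits.shape[0])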
@@ -61,22 +61,26 @@ class TestParallelCrossEntropy:
         test_loss = self.test_loss_func(
             self.input_test, self.tar_test, label_smoothing, reduce_loss, None
         )
-        if reduce_loss:
-            test_loss.backward()
+
         ref_loss = self.ref_loss_func(self.input_ref, self.tar_ref)
+
+        # Handle backward pass based on the test scenario
         if reduce_loss:
+            test_loss.backward()
             ref_loss.backward()
         else:
+            test_loss.sum().backward()
             ref_loss.sum().backward()
 
         test_loss = torch.flatten(test_loss) if not reduce_loss else test_loss
         torch.testing.assert_close(test_loss, ref_loss, check_dtype=False)
         if ignore_idx:
             print(test_loss, ref_loss)
-        if reduce_loss:
-            torch.testing.assert_close(
-                torch.flatten(self.input_test.grad, start_dim=0, end_dim=1), self.input_ref.grad
-            )
+
+        # Compare gradients when backward pass was called
+        torch.testing.assert_close(
+            torch.flatten(self.input_test.grad, start_dim=0, end_dim=1), self.input_ref.grad
+        )
 
         self.input_test = None
         self.input_ref = None
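
Note on the unreduced branch: autograd needs a scalar, so the test sums the per-token losses before calling backward. Summing adds no extra scale factor, which is what makes a single gradient comparison valid for both modes. A self-contained sketch (plain PyTorch, not the fused path) of the identity relied on:

import torch
import torch.nn.functional as F

# Sketch only: grad of summed, unreduced CE is softmax(x) - onehot(y).
x = torch.randn(3, 5, requires_grad=True)
y = torch.randint(0, 5, (3,))
F.cross_entropy(x, y, reduction="none").sum().backward()

expected = torch.softmax(x.detach(), dim=-1)
expected[torch.arange(3), y] -= 1.0
torch.testing.assert_close(x.grad, expected)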
@@ -97,6 +97,7 @@ def cross_entropy_kernel(
     ignore_idx,
     n_cols,
     n_non_ignore,
+    reduce_loss: tl.constexpr,
     label_smoothing: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
 ):
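
Because reduce_loss is declared tl.constexpr, the new branches below are resolved when Triton specializes the kernel, not per element at runtime. An illustrative toy kernel (hypothetical, not the TE kernel) showing the same pattern:

import triton
import triton.language as tl

@triton.jit
def _scale_sketch(x_ptr, n, denom, reduce_loss: tl.constexpr, BLOCK: tl.constexpr):
    # reduce_loss is a compile-time constant: one specialization is generated
    # per value, so the dead branch is eliminated entirely.
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    x = tl.load(x_ptr + offs, mask=offs < n)
    if reduce_loss:
        x = x / denom
    tl.store(x_ptr + offs, x, mask=offs < n)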
@@ -176,7 +177,13 @@ def cross_entropy_kernel(
         if label_smoothing > 0:
             # scale X beforehand to avoid overflow
             scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0))
-        X_block = (tl.exp(X_block - m) / d - eps) / (n_non_ignore)
+        # Scale gradients based on reduction mode
+        # For reduce_loss=True: PyTorch will scale by 1/n_rows, so we need to scale by n_rows/n_non_ignore
+        # For reduce_loss=False: No additional scaling from PyTorch, so we don't scale here
+        if reduce_loss:
+            X_block = (tl.exp(X_block - m) / d - eps) / (n_non_ignore)
+        else:
+            X_block = tl.exp(X_block - m) / d - eps
         tl.store(X_ptr + X_offsets, X_block.to(grad_dtype), mask=X_offsets < n_cols)
 
     # We need tl.debug_barrier() to ensure the new result of X_ptr is written
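
For reference, the quantity stored in place of X is the gradient of label-smoothed cross entropy. A plain-PyTorch equivalent of both scaling modes (smoothed_ce_grad is a hypothetical helper; assumes eps = label_smoothing / n_cols, matching the kernel, and n_non_ignore = n_rows as passed at launch):

import torch

def smoothed_ce_grad(x, y, label_smoothing=0.0, reduce_loss=True):
    # Hypothetical reference sketch, not TE code.
    n_rows, n_cols = x.shape
    eps = label_smoothing / n_cols
    grad = torch.softmax(x, dim=-1) - eps  # tl.exp(X_block - m) / d - eps
    grad[torch.arange(n_rows), y] -= 1.0 - label_smoothing
    if reduce_loss:
        grad = grad / n_rows  # mirrors the / (n_non_ignore) branch
    return grad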
@@ -204,7 +211,11 @@ def cross_entropy_kernel(
     if y >= vocab_start_idx:
         if y < vocab_end_idx:
             X_y = tl.load(X_ptr + y - vocab_start_idx)
-            X_y += -(1 - label_smoothing) / (n_non_ignore)
+            # Apply the same conditional scaling logic for the target token
+            if reduce_loss:
+                X_y += -(1 - label_smoothing) / (n_non_ignore)
+            else:
+                X_y += -(1 - label_smoothing)
             tl.store(X_ptr + y - vocab_start_idx, X_y)
 
     tl.store(loss_ptr, loss)
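
The target-token update combines with the block above: at column y the stored value becomes softmax - eps - (1 - label_smoothing), divided by n_non_ignore only in the reduced case. A self-contained autograd check of that combined formula (plain PyTorch, mean reduction):

import torch
import torch.nn.functional as F

ls = 0.3
x = torch.randn(4, 10, requires_grad=True)
y = torch.randint(0, 10, (4,))
F.cross_entropy(x, y, label_smoothing=ls).backward()

# softmax - eps everywhere, minus (1 - ls) at the target, then / n_rows.
expected = torch.softmax(x.detach(), dim=-1) - ls / x.shape[1]
expected[torch.arange(x.shape[0]), y] -= 1.0 - ls
torch.testing.assert_close(x.grad, expected / x.shape[0])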
@@ -318,6 +329,7 @@ def cross_entropy_forward(
         ignore_idx=ignore_idx,
         n_cols=V,
         n_non_ignore=n_rows,
+        reduce_loss=reduce_loss,
         label_smoothing=label_smoothing,
         BLOCK_SIZE=BLOCK_SIZE,
         num_warps=32,