Unverified Commit 36f2dfd2 authored by Yashaswi Karnati, committed by GitHub

fix ce loss calculation when some tokens are ignored (#2476)



* fix ce loss with ignore idx
Signed-off-by: ykarnati <ykarnati@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: ykarnati <ykarnati@nvidia.com>

* remove fix comments
Signed-off-by: ykarnati <ykarnati@nvidia.com>

* fallback divisor to 1
Signed-off-by: ykarnati <ykarnati@nvidia.com>

* have arg for n_rows and n_non_ignore
Signed-off-by: ykarnati <ykarnati@nvidia.com>

* fuse n_non_ignore to softmax kernel
Signed-off-by: ykarnati <ykarnati@nvidia.com>

* fix incorrect arg
Signed-off-by: ykarnati <ykarnati@nvidia.com>

---------
Signed-off-by: ykarnati <ykarnati@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 8c9f7c25
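For context: with reduction="mean" and an ignore_index, PyTorch's cross entropy divides the summed loss by the number of non-ignored targets, not by the total token count. That is the semantics the reduced-loss path below is brought in line with. A minimal illustration of the reference behavior (tensor values are arbitrary):

import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)
target = torch.tensor([1, 3, -100, -100])  # two of four tokens ignored

mean_loss = F.cross_entropy(logits, target, ignore_index=-100)  # reduction="mean"
per_token = F.cross_entropy(logits, target, ignore_index=-100, reduction="none")

# Ignored positions contribute 0 to per_token; the mean divides by the
# number of non-ignored targets (2), not the batch size (4).
assert torch.allclose(mean_loss, per_token.sum() / 2)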
@@ -89,7 +89,7 @@ class TestParallelCrossEntropy:
         # Check that loss and grad input match
         tols = dtype_tols(dtype)
         test_loss = test_loss.to(dtype=torch.float64, device="cpu")
-        ref_loss = test_loss.to(dtype=torch.float64, device="cpu")
+        ref_loss = ref_loss.to(dtype=torch.float64, device="cpu")
         ref_loss = ref_loss.reshape(test_loss.size())
         test_grad_input = self.input_test.grad.to(dtype=torch.float64, device="cpu")
         ref_grad_input = self.input_ref.grad.to(dtype=torch.float64, device="cpu")
@@ -154,3 +154,16 @@ class TestParallelCrossEntropy:
                 reduce_loss=False,
                 ignore_idx=True,
             )
+
+    def test_ignore_idx_reduced_loss(self):
+        """Test ignore_idx with reduce_loss=True"""
+        self.generate_iters(5)
+        self.generate_infra(True, 0)  # reduce_loss=True
+        for i in range(self.iters):
+            self.one_iteration_test(
+                dtype=torch.float32,
+                swap_dim=random.choice([True, False]),
+                label_smoothing=0,
+                reduce_loss=True,
+                ignore_idx=True,
+            )
@@ -18,6 +18,8 @@ def online_softmax_kernel(
     m_d_X_y_stride,
     rank,
     n_cols,
+    ignore_idx,
+    n_non_ignore,
     BLOCK_SIZE: tl.constexpr,
 ):
     """
@@ -32,6 +34,8 @@ def online_softmax_kernel(
         m_d_X_y_stride (int): The stride of the m/d/X_y tensor.
         rank (int): The rank of this device in the TP group.
         n_cols (int): The number of columns in the input tensor.
+        ignore_idx (int): The index to ignore for loss calculation.
+        n_non_ignore: The number of non-ignored elements in the batch.
         BLOCK_SIZE (int): The block size for Triton operations.
     """
@@ -44,6 +48,9 @@ def online_softmax_kernel(
     Y_ptr += program_id * Y_stride
     y = tl.load(Y_ptr)
 
+    if y != ignore_idx:
+        tl.atomic_add(n_non_ignore, 1)
+
     vocab_start_idx = rank * n_cols
     vocab_end_idx = (rank + 1) * n_cols
     if y >= vocab_start_idx:
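Each program of this kernel handles one target token, so the guarded tl.atomic_add accumulates a global count of non-ignored tokens on the device as a side effect of the softmax pass, with no extra host round trip. A host-side equivalent of what the counter ends up holding (shapes illustrative):

import torch

target = torch.tensor([5, -100, 7, -100, 2])
ignore_idx = -100
# One increment per non-ignored target; matches the atomic counter.
n_non_ignore = (target != ignore_idx).sum()  # tensor(3)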
@@ -89,6 +96,7 @@ def cross_entropy_kernel(
     world_size,
     ignore_idx,
     n_cols,
+    n_rows,
     n_non_ignore,
     reduce_loss: tl.constexpr,
     label_smoothing: tl.constexpr,
@@ -110,12 +118,14 @@ def cross_entropy_kernel(
         world_size (int): The size of world involved in this distributed loss calculation.
         ignore_idx (int): Tokens to be ignored for loss and gradient calculation.
         n_cols (int): The number of columns in the input tensor.
-        n_non_ignore (int): The number of non-ignored elements in the batch.
+        n_rows (int): The number of rows in the batch (B * SQ), used for buffer indexing.
+        n_non_ignore: The number of non-ignored elements in the batch.
         label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
         BLOCK_SIZE (int): The block size for Triton operations.
     """
     program_id = tl.program_id(0).to(tl.int64)
+    n_non_ignore = tl.load(n_non_ignore)
 
     # locate the start index
     X_ptr += program_id * X_stride
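Note that n_non_ignore is now a pointer to a one-element GPU tensor rather than a host-side integer, so the kernel loads the value with tl.load before using it. A minimal sketch of that pattern on a CUDA device (hypothetical kernel, not the PR's):

import torch
import triton
import triton.language as tl

@triton.jit
def scale_by_count_kernel(x_ptr, count_ptr, BLOCK: tl.constexpr):
    # Read the device-side counter once per program, then use it as a scalar.
    n = tl.load(count_ptr).to(tl.float32)
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    x = tl.load(x_ptr + offs)
    tl.store(x_ptr + offs, x / n)

x = torch.ones(128, device="cuda")
count = torch.tensor([4], dtype=torch.int64, device="cuda")
scale_by_count_kernel[(1,)](x, count, BLOCK=128)  # x becomes 0.25 everywhere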
@@ -140,7 +150,7 @@ def cross_entropy_kernel(
     ori_X_y = tl.load(m_d_X_y_ptr + (2 * m_d_X_y_stride))
 
     for i in range(1, world_size):
-        offset = i * 3 * n_non_ignore * m_d_X_y_stride
+        offset = i * 3 * n_rows * m_d_X_y_stride
        access_ptr = m_d_X_y_ptr + offset
         m_new = tl.load(access_ptr)
         d_new = tl.load(access_ptr + m_d_X_y_stride)
...
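The gathered m/d/X_y buffer holds a fixed 3 * n_rows floats per rank regardless of how many targets are ignored, so rank i's block always starts at i * 3 * n_rows * stride. The previous code indexed by n_non_ignore, which only coincides with n_rows when nothing is ignored; with ignored tokens present it would read from inside the wrong rank's block. A small sketch of the layout (numbers illustrative):

n_rows, stride, world_size = 8, 1, 4
block_starts = [i * 3 * n_rows * stride for i in range(world_size)]
# [0, 24, 48, 72] -- fixed block boundaries. A shrunken count such as
# n_non_ignore = 5 would give [0, 15, 30, 45] and misaligned reads.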
@@ -46,6 +46,8 @@ def cross_entropy_forward(
     # tensor to hold this rank's m/d/X_y values
     m_d_X_y = torch.zeros(n_rows * 3, dtype=torch.float32, device=_input.device)
+
+    n_non_ignore = torch.zeros(1, dtype=torch.int64, device=_input.device)
 
     # ensure _input and target are contiguous in the last dimension
     if _input.stride(-1) != 1:
         _input = _input.contiguous()
@@ -63,10 +65,14 @@ def cross_entropy_forward(
         m_d_X_y_stride=m_d_X_y.stride(-1),
         rank=rank,
         n_cols=V,
+        ignore_idx=ignore_idx,
+        n_non_ignore=n_non_ignore,
         BLOCK_SIZE=BLOCK_SIZE,
         num_warps=32,
     )
 
+    n_non_ignore = torch.clamp(n_non_ignore, min=1)
+
     world_size = 1 if dist_process_group is None else dist.get_world_size(dist_process_group)
     if world_size > 1:
@@ -90,14 +96,17 @@ def cross_entropy_forward(
         world_size=world_size,
         ignore_idx=ignore_idx,
         n_cols=V,
-        n_non_ignore=n_rows,
+        n_rows=n_rows,
+        n_non_ignore=n_non_ignore,
         reduce_loss=reduce_loss,
         label_smoothing=label_smoothing,
         BLOCK_SIZE=BLOCK_SIZE,
         num_warps=32,
     )
 
-    loss = torch.reshape(loss_1d, (B, SQ)) if not reduce_loss else (torch.sum(loss_1d) / n_rows)
+    loss = (
+        torch.reshape(loss_1d, (B, SQ)) if not reduce_loss else (torch.sum(loss_1d) / n_non_ignore)
+    )
 
     return loss, _input
...
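Putting it together: the reduced-loss path now sums the per-token losses and divides by the non-ignored count, and the count is clamped to a minimum of 1 so an all-ignored batch yields 0.0 rather than NaN from 0/0. The semantics in plain PyTorch (a sketch of the behavior, not the kernel):

import torch

loss_1d = torch.tensor([0.9, 0.0, 1.1, 0.0])  # zeros at ignored positions
n_non_ignore = torch.tensor([2], dtype=torch.int64)
loss = loss_1d.sum() / torch.clamp(n_non_ignore, min=1)  # tensor([1.0])

# All tokens ignored: summed loss is 0 and the clamped divisor is 1,
# so the result is 0.0 instead of NaN.
empty = torch.zeros(4).sum() / torch.clamp(torch.tensor([0]), min=1)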