"git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on "3d46bf61e3bb336f15ef063b5d72fc3454eb53c2"
Unverified commit 42b51c40, authored by Selvaraj Anandaraj, committed by GitHub
Browse files

Added token ignoring for CE loss (#1789)



* Added token ignoring for CE loss
Signed-off-by: Selvaraj Anandaraj <selvaraja@cw-dfw-cs-001-login-01.cm.cluster>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Added tests
Signed-off-by: root <root@cw-dfw-h100-004-210-013.cm.cluster>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Selvaraj Anandaraj <selvaraja@cw-dfw-cs-001-login-01.cm.cluster>
Co-authored-by: Selvaraj Anandaraj <selvaraja@cw-dfw-cs-001-login-01.cm.cluster>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 27612051
......@@ -19,11 +19,12 @@ class TestParallelCrossEntropy:
label_smoothing=label_smoothing, reduction="mean" if reduce_loss else "none"
)
def generate_input(self, dtype: torch.dtype, swap_dim: bool):
def generate_input(self, dtype: torch.dtype, swap_dim: bool, ignore_idx: bool):
SQ = random.choice([64, 128])
batch = random.choice([1, 2])
vocab = random.choice([64000, 128000])
ignore = random.sample(range(0, SQ - 1), 5)
if swap_dim:
self.input_test = torch.rand((SQ, batch, vocab), dtype=dtype).cuda()
......@@ -32,14 +33,27 @@ class TestParallelCrossEntropy:
self.input_test = torch.rand((batch, SQ, vocab), dtype=dtype).cuda()
self.tar_test = torch.randint(0, vocab, (batch, SQ)).cuda()
if ignore_idx:
for i in ignore:
# Ignore 5 indices
if swap_dim:
self.tar_test[i][0] = -100
else:
self.tar_test[0][i] = -100
self.input_ref = torch.reshape(self.input_test.clone().detach(), (batch * SQ, vocab))
self.tar_ref = torch.reshape(self.tar_test.clone().detach(), (batch * SQ,))
def one_iteration_test(
self, dtype: torch.dtype, swap_dim: bool, label_smoothing: float, reduce_loss: bool
self,
dtype: torch.dtype,
swap_dim: bool,
label_smoothing: float,
reduce_loss: bool,
ignore_idx: bool = False,
):
self.generate_input(dtype, swap_dim)
self.generate_input(dtype, swap_dim, ignore_idx)
self.input_test.requires_grad_(True)
self.input_ref.requires_grad_(True)
......@@ -57,6 +71,8 @@ class TestParallelCrossEntropy:
test_loss = torch.flatten(test_loss) if not reduce_loss else test_loss
torch.testing.assert_close(test_loss, ref_loss, check_dtype=False)
if ignore_idx:
print(test_loss, ref_loss)
if reduce_loss:
torch.testing.assert_close(
torch.flatten(self.input_test.grad, start_dim=0, end_dim=1), self.input_ref.grad
......@@ -106,3 +122,15 @@ class TestParallelCrossEntropy:
self.one_iteration_test(
dtype=torch.float32, swap_dim=False, label_smoothing=0, reduce_loss=False
)
def test_ignore_idx(self):
    """Exercise cross-entropy with ignored target tokens over several random trials.

    Each iteration uses fp32 inputs, no label smoothing, an unreduced loss,
    and a randomly chosen (SQ, B, V) vs (B, SQ, V) input layout, with
    ``ignore_idx=True`` so some targets are set to the ignore index.
    """
    self.generate_iters(5)
    # Infra with label_smoothing disabled (0) and non-swapped default layout.
    self.generate_infra(False, 0)
    layouts = [True, False]
    for _ in range(self.iters):
        self.one_iteration_test(
            dtype=torch.float32,
            swap_dim=random.choice(layouts),
            label_smoothing=0,
            reduce_loss=False,
            ignore_idx=True,
        )
......@@ -22,7 +22,13 @@ class CrossEntropyFunction(torch.autograd.Function):
@staticmethod
def forward(
ctx, _input, target, label_smoothing=0.0, reduce_loss=False, dist_process_group=None
ctx,
_input,
target,
label_smoothing=0.0,
reduce_loss=False,
dist_process_group=None,
ignore_idx=-100,
):
"""
The forward pass of the Cross Entropy loss. If dist_process_group is passed for distributed loss calculation, the input to each
......@@ -35,12 +41,13 @@ class CrossEntropyFunction(torch.autograd.Function):
label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
reduce_loss (bool): If true, returns the averaged loss across the B*SQ dimension.
dist_process_group (torch.dist.ProcessGroup): The distributed process group the loss computation is split across, None if on 1 device.
ignore_idx (int): The index for which loss and gradients are made to zero
Returns:
tensor: The computed loss.
"""
loss, _input = triton_cross_entropy.cross_entropy_forward(
_input, target, label_smoothing, reduce_loss, dist_process_group
_input, target, label_smoothing, reduce_loss, dist_process_group, ignore_idx
)
ctx.save_for_backward(_input.detach())
......
......@@ -94,6 +94,7 @@ def cross_entropy_kernel(
m_d_X_y_stride,
rank,
world_size,
ignore_idx,
n_cols,
n_non_ignore,
label_smoothing: tl.constexpr,
......@@ -113,6 +114,7 @@ def cross_entropy_kernel(
m_d_X_y_stride: The stride of m/d/X_y tensor.
rank (int): The rank of this device in the TP group.
world_size (int): The size of world involved in this distributed loss calculation.
ignore_idx (int): Tokens to be ignored for loss and gradient calculation.
n_cols (int): The number of columns in the input tensor.
n_non_ignore (int): The number of non-ignored elements in the batch.
label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
......@@ -128,6 +130,13 @@ def cross_entropy_kernel(
Y_ptr += program_id * Y_stride
y = tl.load(Y_ptr)
if y == ignore_idx:
# set all X_ptr as 0
for i in range(0, n_cols, BLOCK_SIZE):
X_offsets = i + tl.arange(0, BLOCK_SIZE)
tl.store(X_ptr + X_offsets, 0.0, mask=X_offsets < n_cols)
return
loss_ptr += program_id * loss_stride
m_d_X_y_ptr += program_id * 3 * m_d_X_y_stride
......@@ -247,6 +256,7 @@ def cross_entropy_forward(
label_smoothing: float,
reduce_loss: bool,
dist_process_group: Union[dist.ProcessGroup, None],
ignore_idx: int,
):
"""Forward implementation of Cross Entropy kernel"""
......@@ -305,6 +315,7 @@ def cross_entropy_forward(
m_d_X_y_stride=m_d_X_y_gathered.stride(-1),
rank=rank,
world_size=world_size,
ignore_idx=ignore_idx,
n_cols=V,
n_non_ignore=n_rows,
label_smoothing=label_smoothing,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment