"deploy/cpp_infer/src_system/ocr_cls.cpp" did not exist on "9336e344302e2f0717dd985f9a98e5f7b28576f9"
Commit 08124c8f authored by Tri Dao

[CrossEntropy] Implement logit_scale option

parent 9356a1c0
@@ -12,6 +12,7 @@ class CrossEntropyLoss(nn.Module):
         ignore_index=-100,
         reduction="mean",
         label_smoothing=0.0,
+        logit_scale=1.0,
         lse_square_scale=0.0,
         inplace_backward=False,
         process_group=None,
@@ -33,6 +34,7 @@ class CrossEntropyLoss(nn.Module):
         self.ignore_index = ignore_index
         self.reduction = reduction
         self.label_smoothing = label_smoothing
+        self.logit_scale = logit_scale
         self.lse_square_scale = lse_square_scale
         self.inplace_backward = inplace_backward
         self.process_group = process_group
@@ -50,6 +52,7 @@ class CrossEntropyLoss(nn.Module):
             input,
             target,
             label_smoothing=self.label_smoothing,
+            logit_scale=self.logit_scale,
             lse_square_scale=self.lse_square_scale,
             ignored_index=self.ignore_index,
             inplace_backward=self.inplace_backward,
...
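For context, a minimal usage sketch of the new option (illustrative only, not part of the commit; it assumes the module above is importable as flash_attn.losses.cross_entropy.CrossEntropyLoss). With logit_scale=s, the fused loss matches the stock loss computed on s * logits, without materializing the scaled tensor:

import torch
from flash_attn.losses.cross_entropy import CrossEntropyLoss  # assumed import path

# Equivalent to torch.nn.CrossEntropyLoss()(0.7 * logits, labels),
# but the scaling happens inside the fused Triton kernel.
loss_fn = CrossEntropyLoss(logit_scale=0.7)
logits = torch.randn(8, 50257, device="cuda", requires_grad=True)
labels = torch.randint(0, 50257, (8,), device="cuda")
loss = loss_fn(logits, labels)
loss.backward()  # dlogits carries the extra chain-rule factor of 0.7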
@@ -29,6 +29,7 @@ def cross_entropy_fwd_kernel(
     logits_ptr,
     labels_ptr,
     smoothing,
+    logit_scale,
     lse_square_scale,
     ignored_index,
     total_classes,
@@ -48,7 +49,7 @@ def cross_entropy_fwd_kernel(
     label_idx = tl.load(labels_ptr + row_idx)
     logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to(
         tl.float32
-    )
+    ) * logit_scale
     max_logits = tl.max(logits, 0)
     if HAS_SMOOTHING:
         sum_logits = tl.sum(tl.where(col_offsets < n_cols, logits, 0.0), 0)
@@ -61,7 +62,7 @@ def cross_entropy_fwd_kernel(
         if label_idx >= col_block_idx * BLOCK_SIZE and label_idx < min(
             n_cols, (col_block_idx + 1) * BLOCK_SIZE
         ):
-            logits_label = tl.load(logits_ptr + label_idx)
+            logits_label = tl.load(logits_ptr + label_idx) * logit_scale
             if HAS_SMOOTHING:
                 loss = (
                     (lse if not SPLIT else 0.0)
@@ -94,6 +95,7 @@ def cross_entropy_bwd_kernel(
     lse_ptr,
     labels_ptr,
     smoothing,
+    logit_scale,
     lse_square_scale,
     ignored_index,
     total_classes,
@@ -117,7 +119,7 @@ def cross_entropy_bwd_kernel(
         dloss = 0.0
     logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to(
         tl.float32
-    )
+    ) * logit_scale
     lse = tl.load(lse_ptr + row_idx)
     probs = tl.exp(logits - lse)
     probs += 2.0 * lse_square_scale * lse * probs
@@ -128,16 +130,18 @@ def cross_entropy_bwd_kernel(
         probs = tl.where(col_offsets == label_idx, probs - (1 - smoothing), probs) - smooth_negative
     else:
         probs = tl.where(col_offsets == label_idx, probs - 1.0, probs)
-    tl.store(dlogits_ptr + col_offsets, dloss * probs, mask=col_offsets < n_cols)
+    tl.store(dlogits_ptr + col_offsets, (dloss * logit_scale) * probs, mask=col_offsets < n_cols)


 class CrossEntropyLoss(torch.autograd.Function):
     @staticmethod
     def forward(
         ctx,
         logits,
         labels,
-        smoothing,
+        smoothing=0.0,
+        logit_scale=1.0,
         lse_square_scale=0.0,
         ignored_index=-100,
         inplace_backward=False,
@@ -177,6 +181,7 @@ class CrossEntropyLoss(torch.autograd.Function):
             logits,
             labels,
             smoothing,
+            logit_scale,
             lse_square_scale,
             ignored_index,
             total_classes,
@@ -219,6 +224,7 @@ class CrossEntropyLoss(torch.autograd.Function):
         ctx.save_for_backward(logits, lse, labels)
         ctx.smoothing = smoothing
+        ctx.logit_scale = logit_scale
         ctx.lse_square_scale = lse_square_scale
         ctx.ignored_index = ignored_index
         ctx.total_classes = total_classes
@@ -244,6 +250,7 @@ class CrossEntropyLoss(torch.autograd.Function):
             lse,
             labels,
             ctx.smoothing,
+            ctx.logit_scale,
             ctx.lse_square_scale,
             ctx.ignored_index,
             ctx.total_classes,
@@ -262,6 +269,7 @@ def cross_entropy_loss(
     logits: torch.Tensor,
     labels: torch.Tensor,
     label_smoothing: float = 0.0,
+    logit_scale: float = 1.0,
     lse_square_scale: float = 0.0,
     ignored_index=-100,
     inplace_backward: bool = False,
@@ -272,6 +280,7 @@ def cross_entropy_loss(
         logits: (batch, vocab_size)
         labels: (batch,)
         label_smoothing: float
+        logit_scale: float. Multiply logits by this scale before calculating the loss.
         lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss.
             This is also referred to as "z-loss".
         ignored_index: int. If labels == ignored_index, the loss is set to 0.0.
@@ -286,6 +295,7 @@ def cross_entropy_loss(
         logits,
         labels,
         label_smoothing,
+        logit_scale,
         lse_square_scale,
         ignored_index,
         inplace_backward,
...
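Note how the backward hunk stores (dloss * logit_scale) * probs: the extra logit_scale is the chain-rule factor from differentiating through the scaled logits. A small pure-PyTorch sanity check of that identity, independent of the Triton kernels (a sketch, not part of the commit):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
scale = 0.7
logits = torch.randn(4, 10, dtype=torch.float64, requires_grad=True)
labels = torch.randint(0, 10, (4,))

# Cross entropy on the scaled logits, "mean" reduction over the batch.
loss = F.cross_entropy(logits * scale, labels)
loss.backward()

# Expected gradient: scale * (softmax(scale * logits) - one_hot(labels)) / batch
probs = torch.softmax(logits.detach() * scale, dim=-1)
probs[torch.arange(4), labels] -= 1.0
expected = scale * probs / 4
assert torch.allclose(logits.grad, expected)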
@@ -17,11 +17,15 @@ is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
 # @pytest.mark.parametrize("inplace_backward", [False])
 @pytest.mark.parametrize("lse_square_scale", [0.0, 1e-2])
 # @pytest.mark.parametrize("lse_square_scale", [1e-2])
+@pytest.mark.parametrize("logit_scale", [1.0, 0.7])
+# @pytest.mark.parametrize("logit_scale", [1.0])
 @pytest.mark.parametrize("smoothing", [0.0, 0.9])
 # @pytest.mark.parametrize("smoothing", [0.0])
 @pytest.mark.parametrize("vocab_size", [50257, 128 * 1024])  # test vocab larger than 64k for split
 # @pytest.mark.parametrize("vocab_size", [12])
-def test_cross_entropy_loss(vocab_size, smoothing, lse_square_scale, inplace_backward, dtype):
+def test_cross_entropy_loss(
+    vocab_size, smoothing, logit_scale, lse_square_scale, inplace_backward, dtype
+):
     device = "cuda"
     rtol, atol = (1e-5, 1e-6) if dtype == torch.float32 else (1e-3, 1e-4)
     # set seed
@@ -38,13 +42,14 @@ def test_cross_entropy_loss(vocab_size, smoothing, lse_square_scale, inplace_bac
     model_pt = torch.nn.CrossEntropyLoss(label_smoothing=smoothing)
     model = CrossEntropyLoss(
         label_smoothing=smoothing,
+        logit_scale=logit_scale,
         lse_square_scale=lse_square_scale,
         inplace_backward=inplace_backward,
     )
     out = model(x, y)
-    out_pt = model_pt(x_pt.float(), y)
+    out_pt = model_pt(x_pt.float() * logit_scale, y)
     if lse_square_scale > 0.0:
-        lse_pt = torch.logsumexp(x_pt.float(), dim=-1)
+        lse_pt = torch.logsumexp(x_pt.float() * logit_scale, dim=-1)
         out_pt += lse_square_scale * (lse_pt[y != -100] ** 2).mean()
     assert torch.allclose(out, out_pt, rtol=1e-5, atol=1e-6)
...
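Spelled out as a standalone function, the PyTorch reference this test compares against (combining logit_scale with label smoothing and the z-loss term) looks like the sketch below; reference_loss is a hypothetical name for illustration, not an API in the repo:

import torch
import torch.nn.functional as F

def reference_loss(logits, labels, logit_scale=1.0, lse_square_scale=0.0, smoothing=0.0):
    # Scale the logits up front; this is what the fused kernel does internally.
    scaled = logits.float() * logit_scale
    loss = F.cross_entropy(scaled, labels, label_smoothing=smoothing)
    if lse_square_scale > 0.0:
        # z-loss: mean squared log-sum-exp over the non-ignored tokens.
        lse = torch.logsumexp(scaled, dim=-1)
        loss += lse_square_scale * (lse[labels != -100] ** 2).mean()
    return loss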
@@ -19,6 +19,8 @@ is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
 # @pytest.mark.parametrize("inplace_backward", [False])
 @pytest.mark.parametrize("lse_square_scale", [0.0, 1e-2])
 # @pytest.mark.parametrize("lse_square_scale", [0.0])
+@pytest.mark.parametrize("logit_scale", [0.7])
+# @pytest.mark.parametrize("logit_scale", [1.0])
 @pytest.mark.parametrize("smoothing", [0.0, 0.9])
 # @pytest.mark.parametrize("smoothing", [0.0])
 @pytest.mark.parametrize("vocab_size", [50264, 256 * 1024])  # test vocab larger than 64k for split
@@ -26,7 +28,7 @@ is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
 # @pytest.mark.parametrize("world_size", [1, 2])
 @pytest.mark.parametrize("world_size", [2])
 def test_cross_entropy_loss_parallel(
-    vocab_size, world_size, smoothing, lse_square_scale, inplace_backward, dtype
+    vocab_size, world_size, smoothing, logit_scale, lse_square_scale, inplace_backward, dtype
 ):
     assert vocab_size % world_size == 0
     rtol, atol = (
@@ -59,15 +61,16 @@ def test_cross_entropy_loss_parallel(
     model_pt = torch.nn.CrossEntropyLoss(label_smoothing=smoothing, reduction="none")
     model = CrossEntropyLoss(
         label_smoothing=smoothing,
+        logit_scale=logit_scale,
         reduction="none",
         lse_square_scale=lse_square_scale,
         inplace_backward=inplace_backward,
         process_group=parallel_state.get_tensor_model_parallel_group(),
     )
     out = model(x, y)
-    out_pt = model_pt(x_pt.float(), y)
+    out_pt = model_pt(x_pt.float() * logit_scale, y)
     if lse_square_scale > 0.0:
-        lse_pt = torch.logsumexp(x_pt.float(), dim=-1)
+        lse_pt = torch.logsumexp(x_pt.float() * logit_scale, dim=-1)
         out_pt += lse_square_scale * lse_pt.square()
     out_pt.masked_fill_(y == -100, 0.0)
     assert torch.allclose(out, out_pt, rtol=1e-5, atol=1e-6)
...