Unverified Commit aa0659e5 authored by Kate Cheng's avatar Kate Cheng Committed by GitHub
Browse files

Remove if-else and torch.tensor to meet cudagraph requirement (#1997)



* Remove if-else and torch.tensor to meet cudagraph requirement
Signed-off-by: default avatarKate Cheng <yunhsuanc@nvidia.com>

* Add is_cg_capturable flag to guard the if-else statement
Signed-off-by: default avatarKate Cheng <yunhsuanc@nvidia.com>

---------
Signed-off-by: default avatarKate Cheng <yunhsuanc@nvidia.com>
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 6afca29c
......@@ -29,6 +29,7 @@ class CrossEntropyFunction(torch.autograd.Function):
reduce_loss=False,
dist_process_group=None,
ignore_idx=-100,
is_cg_capturable=False,
):
"""
The forward pass of the Cross Entropy loss. If dist_process_group is passed for distributed loss calculation, the input to each
......@@ -47,10 +48,16 @@ class CrossEntropyFunction(torch.autograd.Function):
tensor: The computed loss.
"""
loss, _input = triton_cross_entropy.cross_entropy_forward(
_input, target, label_smoothing, reduce_loss, dist_process_group, ignore_idx
_input,
target,
label_smoothing,
reduce_loss,
dist_process_group,
ignore_idx,
)
ctx.save_for_backward(_input.detach())
ctx.is_cg_capturable = is_cg_capturable
return loss
@staticmethod
......@@ -66,13 +73,17 @@ class CrossEntropyFunction(torch.autograd.Function):
tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None.
"""
(_input,) = ctx.saved_tensors
_input = triton_cross_entropy.cross_entropy_backward(_input, grad_output)
_input = triton_cross_entropy.cross_entropy_backward(
_input, grad_output, ctx.is_cg_capturable
)
return (
_input,
None,
None,
None,
None,
None,
None,
)
......
......@@ -340,13 +340,17 @@ def cross_entropy_forward(
return loss, _input
def cross_entropy_backward(_input: torch.Tensor, grad_output: torch.Tensor):
def cross_entropy_backward(
_input: torch.Tensor, grad_output: torch.Tensor, is_cg_capturable: bool = False
):
"""Backward implementation of cross entropy loss kernel"""
# If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
# Only check torch.equal when not in CUDA graph capturable mode
if not is_cg_capturable and torch.equal(
grad_output, torch.tensor(1.0, device=grad_output.device)
):
pass
else:
B, SQ, V = _input.shape
n_rows = B * SQ
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment