Unverified commit aa0659e5, authored by Kate Cheng, committed by GitHub
Browse files

Remove if-else and torch.tensor to meet cudagraph requirement (#1997)



* Remove if-else and torch.tensor to meet cudagraph requirement
Signed-off-by: Kate Cheng <yunhsuanc@nvidia.com>

* Add is_cg_capturable flag to guard the if-else statement
Signed-off-by: Kate Cheng <yunhsuanc@nvidia.com>

---------
Signed-off-by: Kate Cheng <yunhsuanc@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 6afca29c
...@@ -29,6 +29,7 @@ class CrossEntropyFunction(torch.autograd.Function): ...@@ -29,6 +29,7 @@ class CrossEntropyFunction(torch.autograd.Function):
reduce_loss=False, reduce_loss=False,
dist_process_group=None, dist_process_group=None,
ignore_idx=-100, ignore_idx=-100,
is_cg_capturable=False,
): ):
""" """
The forward pass of the Cross Entropy loss. If dist_process_group is passed for distributed loss calculation, the input to each The forward pass of the Cross Entropy loss. If dist_process_group is passed for distributed loss calculation, the input to each
...@@ -47,10 +48,16 @@ class CrossEntropyFunction(torch.autograd.Function): ...@@ -47,10 +48,16 @@ class CrossEntropyFunction(torch.autograd.Function):
tensor: The computed loss. tensor: The computed loss.
""" """
loss, _input = triton_cross_entropy.cross_entropy_forward( loss, _input = triton_cross_entropy.cross_entropy_forward(
_input, target, label_smoothing, reduce_loss, dist_process_group, ignore_idx _input,
target,
label_smoothing,
reduce_loss,
dist_process_group,
ignore_idx,
) )
ctx.save_for_backward(_input.detach()) ctx.save_for_backward(_input.detach())
ctx.is_cg_capturable = is_cg_capturable
return loss return loss
@staticmethod @staticmethod
...@@ -66,13 +73,17 @@ class CrossEntropyFunction(torch.autograd.Function): ...@@ -66,13 +73,17 @@ class CrossEntropyFunction(torch.autograd.Function):
tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None. tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None.
""" """
(_input,) = ctx.saved_tensors (_input,) = ctx.saved_tensors
_input = triton_cross_entropy.cross_entropy_backward(_input, grad_output) _input = triton_cross_entropy.cross_entropy_backward(
_input, grad_output, ctx.is_cg_capturable
)
return ( return (
_input, _input,
None, None,
None, None,
None, None,
None, None,
None,
None,
) )
......
...@@ -340,13 +340,17 @@ def cross_entropy_forward( ...@@ -340,13 +340,17 @@ def cross_entropy_forward(
return loss, _input return loss, _input
def cross_entropy_backward(_input: torch.Tensor, grad_output: torch.Tensor): def cross_entropy_backward(
_input: torch.Tensor, grad_output: torch.Tensor, is_cg_capturable: bool = False
):
"""Backward implementation of cross entropy loss kernel""" """Backward implementation of cross entropy loss kernel"""
# If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)): # Only check torch.equal when not in CUDA graph capturable mode
if not is_cg_capturable and torch.equal(
grad_output, torch.tensor(1.0, device=grad_output.device)
):
pass pass
else: else:
B, SQ, V = _input.shape B, SQ, V = _input.shape
n_rows = B * SQ n_rows = B * SQ
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment