Commit ee5cd2ec authored by silencealiang's avatar silencealiang
Browse files

Update cross_entropy.py

parent 491133b9
Pipeline #2553 passed with stage
...@@ -120,8 +120,8 @@ class VocabParallelCrossEntropy: ...@@ -120,8 +120,8 @@ class VocabParallelCrossEntropy:
class _VocabParallelCrossEntropy(torch.autograd.Function): class _VocabParallelCrossEntropy(torch.autograd.Function):
@torch.compile(mode='max-autotune-no-cudagraphs')
@staticmethod @staticmethod
@torch.compile(mode='max-autotune-no-cudagraphs')
def forward(ctx, vocab_parallel_logits, target, label_smoothing=0.0): def forward(ctx, vocab_parallel_logits, target, label_smoothing=0.0):
"""Vocab parallel cross entropy forward function.""" """Vocab parallel cross entropy forward function."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment