Commit 18c8cf95 authored by silencealiang
Browse files

Use torch.compile to fuse some operators and improve performance (compile融合部分算子,提升性能)

parent 38a61e7d
Pipeline #2201 passed with stage
......@@ -120,6 +120,7 @@ class VocabParallelCrossEntropy:
class _VocabParallelCrossEntropy(torch.autograd.Function):
@torch.compile(mode='max-autotune-no-cudagraphs')
@staticmethod
def forward(ctx, vocab_parallel_logits, target, label_smoothing=0.0):
"""Vocab parallel cross entropy forward function."""
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment