Commit a0e873f8 authored by silencealiang

Add torch.compile; reduce graph compilation and CPU overhead

parent 06b52e5b
Pipeline #2222 passed
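For context, the pattern behind this commit pairs torch.compile on a hot forward path with torch._dynamo.disable (wrapped by the no_torch_dynamo helper added below) on thin communication wrappers, so a cached compiled graph is reused instead of being retraced on every call. The following is a minimal, illustrative sketch of that pattern, not code from this patch; the names embed and comm_wrapper are hypothetical, and it assumes PyTorch >= 2.1.

# Illustrative sketch (hypothetical names, not Megatron-LM code); assumes PyTorch >= 2.1.
import torch
import torch._dynamo


@torch.compile(mode='max-autotune-no-cudagraphs')
def embed(weight, tokens):
    # Hot path: traced and compiled once, then reused from the compiled-graph
    # cache, which trims the per-call Python and graph-building work on the CPU.
    return torch.nn.functional.embedding(tokens, weight)


@torch._dynamo.disable(recursive=True)
def comm_wrapper(x):
    # Thin stand-in for a communication wrapper: kept out of Dynamo tracing so
    # it does not force graph breaks or recompilation in compiled callers.
    return x


if __name__ == "__main__":
    weight = torch.randn(16, 8)
    tokens = torch.randint(0, 16, (4,))
    print(comm_wrapper(embed(weight, tokens)).shape)  # torch.Size([4, 8])

In the diffs below, the same idea is applied to VocabParallelEmbedding.forward (compiled) and to the tensor-model-parallel reduce and reduce-scatter wrappers (excluded from Dynamo via no_torch_dynamo()).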
......@@ -8,3 +8,17 @@ jit_fuser = torch.jit.script
# nvFuser is deprecated in PyTorch JIT starting from 2.2
if is_torch_min_version("2.2.0a0"):
    jit_fuser = torch.compile

# Decorator to disable Torch Dynamo
# See: https://github.com/NVIDIA/TransformerEngine/issues/308
no_torch_dynamo = lambda recursive=True: lambda func: func
if torch.__version__ >= "2":
    import torch._dynamo

    if torch.__version__ >= "2.1":
        no_torch_dynamo = lambda recursive=True: lambda f: torch._dynamo.disable(
            f, recursive=recursive
        )
    else:
        # no "recursive" option in pyTorch 2.0 - it acts as if recursive was True
        no_torch_dynamo = lambda recursive=True: torch._dynamo.disable
\ No newline at end of file
......@@ -237,7 +237,7 @@ class VocabParallelEmbedding(torch.nn.Module):
            )
            if config.perform_initialization:
                _initialize_affine_weight_gpu(self.weight, init_method, partition_dim=0, stride=1)

    @torch.compile(mode='max-autotune-no-cudagraphs')
    def forward(self, input_):
        """Forward.
......
......@@ -462,13 +462,13 @@ class _AllToAll(torch.autograd.Function):
# -----------------
# Helper functions.
# -----------------
from megatron.core.jit import no_torch_dynamo

def copy_to_tensor_model_parallel_region(input_):
    """Wrapper for autograd function: forward: copy, backward allreduce"""
    return _CopyToModelParallelRegion.apply(input_)

@no_torch_dynamo()
def reduce_from_tensor_model_parallel_region(input_):
    """Wrapper for autograd function: forward: all reduce, backward copy"""
    return _ReduceFromModelParallelRegion.apply(input_)
......@@ -501,7 +501,7 @@ def gather_from_sequence_parallel_region(
        input_, tensor_parallel_output_grad, group, output_split_sizes, use_global_buffer
    )

@no_torch_dynamo()
def reduce_scatter_to_sequence_parallel_region(
    input_, group=None, input_split_sizes=None, use_global_buffer=False
):
......