Commit 945e5b94 authored by wxj's avatar wxj
Browse files

Update language_model.py, 添加torch.compile

parent 99a0c39e
Pipeline #2185 passed with stage
......@@ -15,6 +15,9 @@ from .module import MegatronModule
from .transformer import ParallelTransformer
from .utils import get_linear_layer, init_method_normal, scaled_init_method_normal
import torch._dynamo
torch._dynamo.config.suppress_errors = True
def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, bias=None):
"""LM logits using word embedding weights."""
......@@ -468,6 +471,7 @@ class TransformerLanguageModel(MegatronModule):
else:
raise Exception('Stage must have at least either encoder or decoder')
@torch.compile(mode="max-autotune-no-cudagraphs")
def forward(
self,
enc_input_ids,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment