"docs/vscode:/vscode.git/clone" did not exist on "24f6d637a24d0590cda4ab60f43a57b28a20c1a0"
Commit 945e5b94 authored by wxj's avatar wxj
Browse files

Update language_model.py, add torch.compile

parent 99a0c39e
Pipeline #2185 passed with stage
@@ -15,6 +15,9 @@ from .module import MegatronModule
from .transformer import ParallelTransformer
from .utils import get_linear_layer, init_method_normal, scaled_init_method_normal
import torch._dynamo
torch._dynamo.config.suppress_errors = True
def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, bias=None):
    """LM logits using word embedding weights."""
@@ -468,6 +471,7 @@ class TransformerLanguageModel(MegatronModule):
        else:
            raise Exception('Stage must have at least either encoder or decoder')
    @torch.compile(mode="max-autotune-no-cudagraphs")
    def forward(
        self,
        enc_input_ids,
        ...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment