Commit 031a4157 authored by wxj

Update transformer.py

parent a45f4c2a
Pipeline #2227 passed with stage
@@ -165,7 +165,7 @@ class ParallelMLP(MegatronModule):
             is_expert=is_expert,
         )
-    @torch.compile(mode="max-autotune-no-cudagraphs")
+    # @torch.compile(mode="max-autotune-no-cudagraphs")
     def forward(self, hidden_states):
         # [s, b, 4hp]
@@ -1213,7 +1213,7 @@ class ParallelTransformerLayer(MegatronModule):
         return retriever_output, norm_input, norm_output
-    @torch.compile(mode="max-autotune-no-cudagraphs")
+    # @torch.compile(mode="max-autotune-no-cudagraphs")
     def forward(self, hidden_states, attention_mask,
                 encoder_output=None, enc_dec_attn_mask=None,
                 retriever_input=None,
......
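Both hunks make the same change: the @torch.compile(mode="max-autotune-no-cudagraphs") decorator is commented out, so ParallelMLP.forward and ParallelTransformerLayer.forward run eagerly again. A minimal sketch of an alternative, assuming a hypothetical USE_TORCH_COMPILE environment flag (not part of this repository), that toggles compilation without editing the source each time:

import os

import torch


# Hypothetical flag, not part of this repository: set USE_TORCH_COMPILE=1
# to re-enable compilation instead of uncommenting the decorator by hand.
_USE_TORCH_COMPILE = os.environ.get("USE_TORCH_COMPILE", "0") == "1"


def maybe_compile(fn):
    # torch.compile (PyTorch >= 2.0) returns an optimized wrapper around fn;
    # when the flag is off, fn is returned unchanged, matching the
    # commented-out decorator in this commit.
    if _USE_TORCH_COMPILE:
        return torch.compile(fn, mode="max-autotune-no-cudagraphs")
    return fn


class ParallelMLP(torch.nn.Module):  # illustrative stand-in, not the real class
    @maybe_compile
    def forward(self, hidden_states):
        return hidden_states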