Commit c5369391 authored by wxj
Browse files

Update transformer.py, 关闭整个模型的torch.compile

parent 473449d8
Pipeline #2221 failed with stages
in 0 seconds
...@@ -165,7 +165,7 @@ class ParallelMLP(MegatronModule): ...@@ -165,7 +165,7 @@ class ParallelMLP(MegatronModule):
is_expert=is_expert, is_expert=is_expert,
) )
@torch.compile(mode="max-autotune-no-cudagraphs") # @torch.compile(mode="max-autotune-no-cudagraphs")
def forward(self, hidden_states): def forward(self, hidden_states):
# [s, b, 4hp] # [s, b, 4hp]
...@@ -1213,7 +1213,7 @@ class ParallelTransformerLayer(MegatronModule): ...@@ -1213,7 +1213,7 @@ class ParallelTransformerLayer(MegatronModule):
return retriever_output, norm_input, norm_output return retriever_output, norm_input, norm_output
@torch.compile(mode="max-autotune-no-cudagraphs") # @torch.compile(mode="max-autotune-no-cudagraphs")
def forward(self, hidden_states, attention_mask, def forward(self, hidden_states, attention_mask,
encoder_output=None, enc_dec_attn_mask=None, encoder_output=None, enc_dec_attn_mask=None,
retriever_input=None, retriever_input=None,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment