Commit 473449d8 authored by wxj
Browse files

Update transformer.py, 添加torch.compile

parent 50fe58fa
Pipeline #2184 failed with stages in 0 seconds
......@@ -1213,6 +1213,7 @@ class ParallelTransformerLayer(MegatronModule):
return retriever_output, norm_input, norm_output
@torch.compile(mode="max-autotune-no-cudagraphs")
def forward(self, hidden_states, attention_mask,
encoder_output=None, enc_dec_attn_mask=None,
retriever_input=None,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment