Commit accbe59f authored by wxj

Update transformer.py

parent 8ebbb6e3
Pipeline #1962 passed
@@ -1229,9 +1229,9 @@ class ParallelTransformerLayer(MegatronModule):
 # hidden_states: [s, b, h]
 # Layer norm at the beginning of the transformer layer.
-from unsloth.kernels.rms_layernorm import fast_rms_layernorm
-norm_output = self.input_norm(hidden_states) if not args.use_fast_rms_layernorm else fast_rms_layernorm(self.input_norm, hidden_states)
-#norm_output = self.input_norm(hidden_states)
+# from unsloth.kernels.rms_layernorm import fast_rms_layernorm
+# norm_output = self.input_norm(hidden_states) if not args.use_fast_rms_layernorm else fast_rms_layernorm(self.input_norm, hidden_states)
+norm_output = self.input_norm(hidden_states)
 # Self attention.
 attention_output, attention_bias = \
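This commit backs out an experimental fast path: the unconditional import of unsloth's fused RMSNorm kernel and the flag-gated call are commented out, restoring the plain `self.input_norm(hidden_states)` call. As a minimal sketch (an assumption, not this repository's code), the same speedup could be kept optional by moving the import inside the flag check with a fallback; `_apply_input_norm` is a hypothetical helper name, while `args.use_fast_rms_layernorm` and the `fast_rms_layernorm(self.input_norm, hidden_states)` call come straight from the diff:

```python
# Hedged sketch, not the repository's code: gate the optional kernel behind
# the flag with a lazy import instead of importing it unconditionally.
def _apply_input_norm(self, hidden_states, args):
    if getattr(args, "use_fast_rms_layernorm", False):
        try:
            # Lazy import: environments without unsloth still work.
            from unsloth.kernels.rms_layernorm import fast_rms_layernorm
            return fast_rms_layernorm(self.input_norm, hidden_states)
        except ImportError:
            pass  # fall through to the standard norm below
    # Default path, matching the line restored by this commit.
    return self.input_norm(hidden_states)
```

Gating the import this way keeps unsloth an optional dependency: installs without it take the standard path, and installs with it use the fused kernel only when the flag is set.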