Commit 30b92cf5 authored by mshoeybi

resolved conflicts

parent 8cb389b8
@@ -609,18 +609,6 @@ class ParallelTransformer(MegatronModule):
                 return x_
             return custom_forward
-<<<<<<< HEAD
-        l = 0
-        while l < self.num_layers:
-            hidden_states = mpu.checkpoint(
-                custom(l, l + self.checkpoint_num_layers),
-                self.distribute_checkpointed_activations,
-                hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
-            l += self.checkpoint_num_layers
-=======
         # Make sure memory is freed.
         mpu.reset_checkpointed_activations_memory_buffer()
         if self.activations_checkpoint_method == 'uniform':
             # Uniformly divide the total number of Transformer layers and
             # checkpoint the input activation of each divided chunk.
@@ -629,6 +617,7 @@ class ParallelTransformer(MegatronModule):
             while l < self.num_layers:
                 hidden_states = mpu.checkpoint(
                     custom(l, l + self.activations_checkpoint_num_layers),
+                    self.distribute_checkpointed_activations,
                     hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                 l += self.activations_checkpoint_num_layers
         elif self.activations_checkpoint_method == 'block':
@@ -639,13 +628,13 @@ class ParallelTransformer(MegatronModule):
                 if l < self.activations_checkpoint_num_layers:
                     hidden_states = mpu.checkpoint(
                         custom(l, l + 1),
+                        self.distribute_checkpointed_activations,
                         hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                 else:
                     hidden_states = custom(l, l + 1)(
                         hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
         else:
             raise ValueError("Invalid activation checkpoint method.")
->>>>>>> main
         return hidden_states
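
For reference, the resolution keeps the main branch's two activation-checkpointing modes and threads self.distribute_checkpointed_activations through each mpu.checkpoint call: 'uniform' checkpoints every chunk of activations_checkpoint_num_layers consecutive layers, while 'block' checkpoints only the first activations_checkpoint_num_layers layers and runs the rest without recomputation. Below is a minimal standalone sketch of that layer-partitioning logic; it is plain Python for illustration only, not Megatron-LM's mpu.checkpoint machinery, and the function names are made up here.

# A minimal sketch (not Megatron-LM's implementation) of how the two
# resolved checkpointing methods partition the transformer layers.
# `num_layers` and `checkpoint_num_layers` mirror the attributes used
# in the diff above; the function names are illustrative only.

def uniform_chunks(num_layers, checkpoint_num_layers):
    """'uniform': split all layers into chunks; every chunk's input
    activation is checkpointed and recomputed during backward."""
    chunks = []
    l = 0
    while l < num_layers:
        chunks.append(list(range(l, min(l + checkpoint_num_layers, num_layers))))
        l += checkpoint_num_layers
    return chunks

def block_checkpointed_layers(num_layers, checkpoint_num_layers):
    """'block': checkpoint only the first `checkpoint_num_layers`
    layers; the remaining layers run without recomputation."""
    return [l for l in range(num_layers) if l < checkpoint_num_layers]

if __name__ == "__main__":
    print(uniform_chunks(8, 2))              # [[0, 1], [2, 3], [4, 5], [6, 7]]
    print(block_checkpointed_layers(8, 2))   # [0, 1]

The trade-off the two modes expose: 'uniform' caps activation memory across the whole stack at the cost of recomputing every layer, while 'block' spends the available memory budget on the later layers and limits recomputation to the first few.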