Commit 30b92cf5 authored by mshoeybi

resolved conflicts

parent 8cb389b8
@@ -609,18 +609,6 @@ class ParallelTransformer(MegatronModule):
                     return x_
                 return custom_forward
-<<<<<<< HEAD
-        l = 0
-        while l < self.num_layers:
-            hidden_states = mpu.checkpoint(
-                custom(l, l + self.checkpoint_num_layers),
-                self.distribute_checkpointed_activations,
-                hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
-            l += self.checkpoint_num_layers
-=======
-        # Make sure memory is freed.
-        mpu.reset_checkpointed_activations_memory_buffer()
         if self.activations_checkpoint_method == 'uniform':
             # Uniformly divide the total number of Transformer layers and checkpoint
             # the input activation of each divided chunk.
@@ -629,6 +617,7 @@ class ParallelTransformer(MegatronModule):
             while l < self.num_layers:
                 hidden_states = mpu.checkpoint(
                     custom(l, l + self.activations_checkpoint_num_layers),
+                    self.distribute_checkpointed_activations,
                     hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                 l += self.activations_checkpoint_num_layers
         elif self.activations_checkpoint_method == 'block':
@@ -639,13 +628,13 @@ class ParallelTransformer(MegatronModule):
                 if l < self.activations_checkpoint_num_layers:
                     hidden_states = mpu.checkpoint(
                         custom(l, l + 1),
+                        self.distribute_checkpointed_activations,
                         hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                 else:
                     hidden_states = custom(l, l + 1)(
                         hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
         else:
             raise ValueError("Invalid activation checkpoint method.")
->>>>>>> main
         return hidden_states
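For context, the resolved hunk keeps the main-branch dispatch between 'uniform' checkpointing (checkpoint the input of every chunk of activations_checkpoint_num_layers layers) and 'block' checkpointing (checkpoint only the first few layers, run the rest normally), while carrying over the distribute_checkpointed_activations argument from HEAD, which in Megatron-LM distributes the saved inputs across the tensor-model-parallel group. The following is a minimal, self-contained sketch of that uniform/block dispatch, not Megatron's code: it uses torch.utils.checkpoint.checkpoint in place of mpu.checkpoint (so the distribute flag is omitted), and names such as SketchTransformer, hidden, and num_checkpoint_layers are illustrative.

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class SketchTransformer(nn.Module):
    """Toy stack of layers illustrating 'uniform' vs 'block' activation checkpointing."""

    def __init__(self, num_layers=8, hidden=64,
                 checkpoint_method='uniform', num_checkpoint_layers=2):
        super().__init__()
        # Stand-in for the Transformer layers; assumes num_checkpoint_layers
        # evenly divides num_layers, as the diff's uniform loop does.
        self.layers = nn.ModuleList(
            nn.Linear(hidden, hidden) for _ in range(num_layers))
        self.num_layers = num_layers
        self.checkpoint_method = checkpoint_method
        self.num_checkpoint_layers = num_checkpoint_layers

    def _custom(self, start, end):
        # Mirrors custom(l, l + n) in the diff: run layers [start, end).
        def custom_forward(x):
            for index in range(start, end):
                x = self.layers[index](x)
            return x
        return custom_forward

    def forward(self, hidden_states):
        if self.checkpoint_method == 'uniform':
            # Checkpoint the input of each chunk of num_checkpoint_layers layers;
            # intermediate activations inside a chunk are recomputed in backward.
            l = 0
            while l < self.num_layers:
                hidden_states = checkpoint(
                    self._custom(l, l + self.num_checkpoint_layers),
                    hidden_states, use_reentrant=False)
                l += self.num_checkpoint_layers
        elif self.checkpoint_method == 'block':
            # Checkpoint only the first num_checkpoint_layers individual layers;
            # the remaining layers keep their activations resident in memory.
            for l in range(self.num_layers):
                fn = self._custom(l, l + 1)
                if l < self.num_checkpoint_layers:
                    hidden_states = checkpoint(fn, hidden_states,
                                               use_reentrant=False)
                else:
                    hidden_states = fn(hidden_states)
        else:
            raise ValueError("Invalid activation checkpoint method.")
        return hidden_states


if __name__ == "__main__":
    model = SketchTransformer(checkpoint_method='block')
    out = model(torch.randn(4, 64, requires_grad=True))
    out.sum().backward()  # recomputes checkpointed layers during backward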