Commit 24369dd6 authored by Lawrence McAfee's avatar Lawrence McAfee
Browse files

limit 'make_viewless_tensor()' to case of micro_batch_size == 1; added comment

parent 0c8e8cce
...@@ -557,6 +557,7 @@ class ParallelTransformer(MegatronModule): ...@@ -557,6 +557,7 @@ class ParallelTransformer(MegatronModule):
self.pre_process = pre_process self.pre_process = pre_process
self.post_process = post_process self.post_process = post_process
self.input_tensor = None self.input_tensor = None
self.micro_batch_size = args.micro_batch_size
# Store activation checkpoiting flag. # Store activation checkpoiting flag.
self.activations_checkpoint_method = args.activations_checkpoint_method self.activations_checkpoint_method = args.activations_checkpoint_method
...@@ -696,11 +697,17 @@ class ParallelTransformer(MegatronModule): ...@@ -696,11 +697,17 @@ class ParallelTransformer(MegatronModule):
hidden_states = self.input_tensor hidden_states = self.input_tensor
# Viewless tensor # Viewless tensor
hidden_states = mpu.make_viewless_tensor( # We only need to create a viewless tensor in the case of micro batch
hidden_states, # size (mbs) == 1, since in this case, 'hidden_states.transpose()'
requires_grad = True, # above creates a view tensor, and '.contiguous()' is a pass-through.
keep_graph = True, # For mbs >= 2, '.contiguous()' creates a new tensor, eliminating
) # the need to make it viewless.
if self.micro_batch_size == 1:
hidden_states = mpu.make_viewless_tensor(
hidden_states,
requires_grad = True,
keep_graph = True,
)
if encoder_output is not None: if encoder_output is not None:
encoder_output = encoder_output.transpose(0, 1).contiguous() encoder_output = encoder_output.transpose(0, 1).contiguous()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment