Commit 24369dd6 authored by Lawrence McAfee's avatar Lawrence McAfee
Browse files

limit 'make_viewless_tensor()' to case of micro_batch_size == 1; added comment

parent 0c8e8cce
...@@ -557,6 +557,7 @@ class ParallelTransformer(MegatronModule): ...@@ -557,6 +557,7 @@ class ParallelTransformer(MegatronModule):
self.pre_process = pre_process self.pre_process = pre_process
self.post_process = post_process self.post_process = post_process
self.input_tensor = None self.input_tensor = None
self.micro_batch_size = args.micro_batch_size
# Store activation checkpoiting flag. # Store activation checkpoiting flag.
self.activations_checkpoint_method = args.activations_checkpoint_method self.activations_checkpoint_method = args.activations_checkpoint_method
...@@ -696,6 +697,12 @@ class ParallelTransformer(MegatronModule): ...@@ -696,6 +697,12 @@ class ParallelTransformer(MegatronModule):
hidden_states = self.input_tensor hidden_states = self.input_tensor
# Viewless tensor # Viewless tensor
# We only need to create a viewless tensor in the case of micro batch
# size (mbs) == 1, since in this case, 'hidden_states.transpose()'
# above creates a view tensor, and '.contiguous()' is a pass-through.
# For mbs >= 2, '.contiguous()' creates a new tensor, eliminating
# the need to make it viewless.
if self.micro_batch_size == 1:
hidden_states = mpu.make_viewless_tensor( hidden_states = mpu.make_viewless_tensor(
hidden_states, hidden_states,
requires_grad = True, requires_grad = True,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment