Commit d16e2a24 authored by Lawrence McAfee's avatar Lawrence McAfee
Browse files

further clarified viewless tensor comment in transformer.py

parent 24369dd6
...@@ -557,7 +557,6 @@ class ParallelTransformer(MegatronModule): ...@@ -557,7 +557,6 @@ class ParallelTransformer(MegatronModule):
self.pre_process = pre_process self.pre_process = pre_process
self.post_process = post_process self.post_process = post_process
self.input_tensor = None self.input_tensor = None
self.micro_batch_size = args.micro_batch_size
# Store activation checkpoiting flag. # Store activation checkpoiting flag.
self.activations_checkpoint_method = args.activations_checkpoint_method self.activations_checkpoint_method = args.activations_checkpoint_method
...@@ -696,22 +695,32 @@ class ParallelTransformer(MegatronModule): ...@@ -696,22 +695,32 @@ class ParallelTransformer(MegatronModule):
# See set_input_tensor() # See set_input_tensor()
hidden_states = self.input_tensor hidden_states = self.input_tensor
# Viewless tensor # Viewless tensor.
# We only need to create a viewless tensor in the case of micro batch # - We only need to create a viewless tensor in the case of micro batch
# size (mbs) == 1, since in this case, 'hidden_states.transpose()' # size (mbs) == 1, since in this case, 'hidden_states.transpose()'
# above creates a view tensor, and '.contiguous()' is a pass-through. # above creates a view tensor, and '.contiguous()' is a pass-through.
# For mbs >= 2, '.contiguous()' creates a new tensor, eliminating # For mbs >= 2, '.contiguous()' creates a new tensor, eliminating
# the need to make it viewless. # the need to make it viewless.
if self.micro_batch_size == 1: #
hidden_states = mpu.make_viewless_tensor( # However, we don't explicitly check mbs == 1 here because
hidden_states, # make_viewless_tensor() has negligible overhead when its input
requires_grad = True, # is already viewless.
keep_graph = True, #
) # - For the 'else' case above, calling make_viewless_tensor() here is
# likely redundant, since p2p_communication.py (likely originator)
# already creates viewless tensors. That said, make_viewless_tensor()
# is called here to be future-proof and corner-case-proof.
hidden_states = mpu.make_viewless_tensor(
hidden_states,
requires_grad = True,
keep_graph = True,
)
# Transpose encoder output.
if encoder_output is not None: if encoder_output is not None:
encoder_output = encoder_output.transpose(0, 1).contiguous() encoder_output = encoder_output.transpose(0, 1).contiguous()
# Forward pass.
if self.activations_checkpoint_method is not None: if self.activations_checkpoint_method is not None:
hidden_states = self._checkpointed_forward(hidden_states, hidden_states = self._checkpointed_forward(hidden_states,
attention_mask, attention_mask,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment