Commit df521589 authored by Mostofa Patwary

addressed the comments

parent 1e0e555c
@@ -486,7 +486,7 @@ class ParallelAttention(MegatronModule):
         # =================================================
         # Pre-allocate memory for key-values for inference.
         # =================================================
+        is_first_step = False
         if inference_params:
             if self.layer_number not in inference_params.key_value_memory_dict:
                 inf_max_seq_len = inference_params.max_sequence_len
@@ -497,6 +497,7 @@ class ParallelAttention(MegatronModule):
                     inf_max_seq_len, inf_max_batch_size)
                 inference_params.key_value_memory_dict[self.layer_number] = (
                     inference_key_memory, inference_value_memory)
+                is_first_step = True
             else:
                 inference_key_memory, inference_value_memory = \
                     inference_params.key_value_memory_dict[self.layer_number]
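
The hunk above introduces an is_first_step flag that is only True on the call that first allocates the per-layer key/value cache for inference; later calls reuse the existing buffers. The following is a minimal, self-contained sketch of that pattern, not Megatron-LM's implementation: ToyInferenceParams and toy_attention_step are hypothetical names used only to show how the flag falls out of the allocation branch so the caller can treat the full-prompt pass differently from later single-token passes.

import torch

class ToyInferenceParams:
    # Hypothetical stand-in for an inference-params container (not Megatron-LM's class).
    def __init__(self, max_batch_size, max_sequence_len):
        self.max_batch_size = max_batch_size
        self.max_sequence_len = max_sequence_len
        self.sequence_len_offset = 0      # tokens already written into the cache
        self.key_value_memory_dict = {}   # layer_number -> (key_memory, value_memory)

def toy_attention_step(layer_number, key, value, inference_params):
    # key/value hold only this step's projections, shaped [seq_len, batch, heads, head_dim].
    is_first_step = False
    if layer_number not in inference_params.key_value_memory_dict:
        # First call for this layer: pre-allocate the full-size cache once.
        shape = (inference_params.max_sequence_len,
                 inference_params.max_batch_size) + key.shape[2:]
        key_memory = torch.empty(shape, dtype=key.dtype, device=key.device)
        value_memory = torch.empty(shape, dtype=value.dtype, device=value.device)
        inference_params.key_value_memory_dict[layer_number] = (key_memory, value_memory)
        is_first_step = True
    else:
        key_memory, value_memory = inference_params.key_value_memory_dict[layer_number]

    # Append this step's keys/values into the pre-allocated buffers.
    start = inference_params.sequence_len_offset
    end = start + key.size(0)
    key_memory[start:end, :key.size(1)] = key
    value_memory[start:end, :value.size(1)] = value
    return key_memory[:end], value_memory[:end], is_first_step

params = ToyInferenceParams(max_batch_size=2, max_sequence_len=16)
k = v = torch.randn(5, 2, 8, 64)                      # 5-token prompt
_, _, first = toy_attention_step(0, k, v, params)     # first == True
params.sequence_len_offset += 5
k = v = torch.randn(1, 2, 8, 64)                      # one generated token
_, _, first = toy_attention_step(0, k, v, params)     # first == False

The flag distinguishes the prompt pass, where the whole provided sequence lands in the cache at once, from subsequent decoding steps that append a single token, which is the kind of downstream decision the added variable makes available.
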
@@ -741,14 +742,12 @@ class ParallelTransformerLayer(MegatronModule):
         layernorm_output = self.input_layernorm(hidden_states)
         # Self attention.
         self_attention_pos_emb = None
-        if rotary_pos_emb is not None:
-            self_attention_pos_emb = rotary_pos_emb
         attention_output, attention_bias = \
             self.self_attention(
                 layernorm_output,
                 attention_mask,
                 inference_params=inference_params,
-                rotary_pos_emb=self_attention_pos_emb)
+                rotary_pos_emb=rotary_pos_emb)
         # Residual connection.
         if self.apply_residual_connection_post_layernorm:
...
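
The second hunk drops the caller-side indirection through self_attention_pos_emb and passes rotary_pos_emb straight into self.self_attention: the argument is either None or a precomputed rotary-frequency tensor, and the attention module is the natural place to decide whether it applies. Below is a minimal sketch of that division of responsibility, assuming an interleaved RoPE formulation; apply_rotary and toy_self_attention are hypothetical names, not Megatron-LM's API.

import torch

def apply_rotary(x, freqs):
    # Sketch only, not Megatron-LM code: rotate interleaved (even, odd) channel
    # pairs by per-position angles, the standard RoPE form.
    # x: [seq, batch, heads, dim]; freqs: [seq, 1, 1, dim // 2] angles in radians.
    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    cos, sin = freqs.cos(), freqs.sin()
    out = torch.empty_like(x)
    out[..., 0::2] = x_even * cos - x_odd * sin
    out[..., 1::2] = x_even * sin + x_odd * cos
    return out

def toy_self_attention(query, key, rotary_pos_emb=None):
    # The callee decides whether rotary embeddings apply, so callers can forward
    # rotary_pos_emb unchanged, whether it is a tensor or None.
    if rotary_pos_emb is not None:
        query = apply_rotary(query, rotary_pos_emb)
        key = apply_rotary(key, rotary_pos_emb)
    scores = torch.einsum('sbhd,tbhd->bhst', query, key) / query.size(-1) ** 0.5
    return scores.softmax(dim=-1)

q = k = torch.randn(4, 1, 2, 8)
freqs = torch.rand(4, 1, 1, 4)
probs_with_rope = toy_self_attention(q, k, rotary_pos_emb=freqs)  # rotary applied
probs_plain = toy_self_attention(q, k, rotary_pos_emb=None)       # passed through untouched

Since None flows through the keyword argument unchanged, the removed caller-side guard added no behavior.
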