Commit df521589 authored by Mostofa Patwary

addressed the comments

parent 1e0e555c
@@ -486,7 +486,7 @@ class ParallelAttention(MegatronModule):
         # =================================================
         # Pre-allocate memory for key-values for inference.
         # =================================================
+        is_first_step = False
         if inference_params:
             if self.layer_number not in inference_params.key_value_memory_dict:
                 inf_max_seq_len = inference_params.max_sequence_len
@@ -497,6 +497,7 @@ class ParallelAttention(MegatronModule):
                     inf_max_seq_len, inf_max_batch_size)
                 inference_params.key_value_memory_dict[self.layer_number] = (
                     inference_key_memory, inference_value_memory)
+                is_first_step = True
             else:
                 inference_key_memory, inference_value_memory = \
                     inference_params.key_value_memory_dict[self.layer_number]
@@ -741,14 +742,12 @@ class ParallelTransformerLayer(MegatronModule):
         layernorm_output = self.input_layernorm(hidden_states)
         # Self attention.
-        self_attention_pos_emb = None
-        if rotary_pos_emb is not None:
-            self_attention_pos_emb = rotary_pos_emb
         attention_output, attention_bias = \
             self.self_attention(
                 layernorm_output,
                 attention_mask,
                 inference_params=inference_params,
-                rotary_pos_emb=self_attention_pos_emb)
+                rotary_pos_emb=rotary_pos_emb)
         # Residual connection.
         if self.apply_residual_connection_post_layernorm:
......
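The first two hunks add an is_first_step flag around the inference key/value cache: the per-layer buffers are allocated once, sized for the maximum sequence length and batch size, and the flag records whether the current forward pass is the one that performed that allocation. The third hunk simply passes rotary_pos_emb straight through to self.self_attention instead of copying it into a temporary. Below is a minimal standalone sketch of the first pattern; SimpleInferenceParams and get_kv_cache are hypothetical names for illustration, not the real Megatron-LM API.

# A minimal sketch (not the Megatron-LM code) of the pattern this commit
# touches: key/value buffers for incremental decoding are allocated once,
# on the first inference step, and an is_first_step flag records whether the
# current call is that allocation step. All names here are illustrative.
import torch


class SimpleInferenceParams:
    """Holds per-layer KV-cache buffers across decoding steps (illustrative)."""

    def __init__(self, max_sequence_len, max_batch_size):
        self.max_sequence_len = max_sequence_len
        self.max_batch_size = max_batch_size
        self.key_value_memory_dict = {}  # layer_number -> (key_mem, value_mem)


def get_kv_cache(inference_params, layer_number, num_heads, head_dim,
                 device="cpu", dtype=torch.float32):
    """Return (key_mem, value_mem, is_first_step) for one attention layer."""
    is_first_step = False
    if layer_number not in inference_params.key_value_memory_dict:
        # First step: pre-allocate [max_seq_len, max_batch, heads, head_dim].
        shape = (inference_params.max_sequence_len,
                 inference_params.max_batch_size,
                 num_heads, head_dim)
        key_mem = torch.empty(shape, device=device, dtype=dtype)
        value_mem = torch.empty(shape, device=device, dtype=dtype)
        inference_params.key_value_memory_dict[layer_number] = (key_mem, value_mem)
        is_first_step = True
    else:
        # Later steps: reuse the buffers allocated on the first step.
        key_mem, value_mem = inference_params.key_value_memory_dict[layer_number]
    return key_mem, value_mem, is_first_step


# Usage: the flag is True only on the allocating (first) call for a layer.
params = SimpleInferenceParams(max_sequence_len=128, max_batch_size=2)
_, _, first = get_kv_cache(params, layer_number=1, num_heads=8, head_dim=64)
assert first is True
_, _, first = get_kv_cache(params, layer_number=1, num_heads=8, head_dim=64)
assert first is False

Allocating the full-size buffers up front means later decoding steps only write new key/value slices into existing memory rather than re-allocating and copying the cache on every step.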