Commit 07ffa1b7 authored by Mostofa Patwary

addressing more comments

parent df521589
@@ -453,8 +453,6 @@ class ParallelAttention(MegatronModule):
             key_layer = inputs[1]
             value_layer = inputs[2]
             attention_mask = inputs[3]
-            rotary_pos_emb = inputs[4] if inputs[4] is None \
-                else (inputs[4], inputs[5])
             output_ = self.core_attention(query_layer, key_layer,
                                           value_layer, attention_mask)
             return output_
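Context for the hunk above: the removed lines re-packed the flattened rotary embedding arguments (inputs[4], inputs[5]) back into a tuple inside the checkpointed forward, even though core_attention is called without it. A minimal, hypothetical sketch of that flatten/re-pack pattern using torch.utils.checkpoint; the attention_stub name and tensor shapes are illustrative and not part of this commit.

import torch
from torch.utils.checkpoint import checkpoint

def attention_stub(query, key, value, mask, q_pos_emb=None, k_pos_emb=None):
    # Re-pack the flattened pair into a tuple, mirroring the removed lines;
    # a real implementation would apply it to query/key before the matmul.
    rotary_pos_emb = None if q_pos_emb is None else (q_pos_emb, k_pos_emb)
    return torch.matmul(query, key.transpose(-1, -2))

q = torch.randn(2, 4, 8, requires_grad=True)
k = torch.randn(2, 4, 8, requires_grad=True)
v = torch.randn(2, 4, 8)
# checkpoint() takes a flat argument list, so a (q, k) rotary tuple has to be
# flattened at the call site and rebuilt inside the checkpointed function.
out = checkpoint(attention_stub, q, k, v, None, use_reentrant=False)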
@@ -548,8 +546,10 @@ class ParallelAttention(MegatronModule):
         # duplicate the pos_emb for self attention
         if rotary_pos_emb is not None:
-            rotary_pos_emb = rotary_pos_emb if isinstance(rotary_pos_emb, \
-                tuple) else ((rotary_pos_emb,) * 2)
+            if isinstance(rotary_pos_emb, tuple):
+                rotary_pos_emb = rotary_pos_emb
+            else:
+                rotary_pos_emb = ((rotary_pos_emb,) * 2)

         if inference_params:
             batch_start = inference_params.batch_size_offset
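The replaced ternary and the new if/else both normalize rotary_pos_emb into a pair so that downstream code can always unpack separate query and key embeddings; for self attention the single embedding is simply duplicated. A minimal sketch of that behavior follows; the helper name and tensor shape are illustrative, not part of the commit.

import torch

def normalize_rotary_pos_emb(rotary_pos_emb):
    # Mirrors the if/else added above: leave an existing (q, k) pair alone,
    # otherwise duplicate the single embedding for self attention.
    if rotary_pos_emb is None:
        return None
    if isinstance(rotary_pos_emb, tuple):
        return rotary_pos_emb
    return (rotary_pos_emb,) * 2

freqs = torch.randn(16, 1, 1, 32)            # dummy rotary frequencies
q_pos_emb, k_pos_emb = normalize_rotary_pos_emb(freqs)
assert q_pos_emb is k_pos_emb                # same tensor reused for q and k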