Commit 07ffa1b7 authored by Mostofa Patwary

addressing more comments

parent df521589
@@ -453,8 +453,6 @@ class ParallelAttention(MegatronModule):
             key_layer = inputs[1]
             value_layer = inputs[2]
             attention_mask = inputs[3]
-            rotary_pos_emb = inputs[4] if inputs[4] is None \
-                else (inputs[4], inputs[5])
             output_ = self.core_attention(query_layer, key_layer,
                                           value_layer, attention_mask)
             return output_
@@ -548,8 +546,10 @@ class ParallelAttention(MegatronModule):
         # duplicate the pos_emb for self attention
         if rotary_pos_emb is not None:
-            rotary_pos_emb = rotary_pos_emb if isinstance(rotary_pos_emb, \
-                tuple) else ((rotary_pos_emb,) * 2)
+            if isinstance(rotary_pos_emb, tuple):
+                rotary_pos_emb = rotary_pos_emb
+            else:
+                rotary_pos_emb = ((rotary_pos_emb,) * 2)

         if inference_params:
             batch_start = inference_params.batch_size_offset
...
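For readers skimming the diff, the second hunk replaces a line-continued ternary with an explicit if/else that normalizes `rotary_pos_emb` into a (query, key) pair before self attention. Below is a minimal standalone sketch (not Megatron-LM code) of that normalization; the helper name `normalize_rotary_pos_emb` and the plain-string stand-ins for the embedding tensors are assumptions made for illustration only.

def normalize_rotary_pos_emb(rotary_pos_emb):
    """Return a (q_pos_emb, k_pos_emb) tuple, or None if no embedding is given."""
    if rotary_pos_emb is None:
        return None
    if isinstance(rotary_pos_emb, tuple):
        # Caller already supplied separate embeddings for query and key.
        return rotary_pos_emb
    # Single embedding: reuse it for both query and key (self attention).
    return (rotary_pos_emb,) * 2


if __name__ == "__main__":
    # Stand-in for the actual rotary embedding tensor.
    emb = "rotary_emb"
    assert normalize_rotary_pos_emb(None) is None
    assert normalize_rotary_pos_emb(emb) == (emb, emb)
    assert normalize_rotary_pos_emb((emb, emb)) == (emb, emb)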