Unverified commit 8d62d5c2 authored by Przemyslaw Tredak, committed by GitHub

Use fused implementation of RoPE in MultiHeadAttention (#658)



* Use fused implementation of RoPE in MultiHeadAttention
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fix freqs dtype
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 1e780946
@@ -811,7 +811,7 @@ def _run_transformer_layer(
     rotary_pos_emb = None
     if RoPE:
         PE = RotaryPositionEmbedding(dim=config.head_dim)
-        rotary_pos_emb = PE(config.max_seqlen_q).to(dtype=dtype, device="cuda")
+        rotary_pos_emb = PE(config.max_seqlen_q).to(device="cuda")
     # Set up model
     block = (
...
@@ -3625,8 +3625,8 @@ class MultiheadAttention(torch.nn.Module):
         # apply relative positional encoding (rotary embedding)
         if rotary_pos_emb is not None:
             q_pos_emb, k_pos_emb = rotary_pos_emb
-            query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb, self.qkv_format)
-            key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb, self.qkv_format)
+            query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb, self.qkv_format, fused=True)
+            key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb, self.qkv_format, fused=True)

         context_layer = self.core_attention(
             query_layer,
...
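A minimal sketch of the fused path this commit enables, assuming Transformer Engine's torch-level API (`RotaryPositionEmbedding` and `apply_rotary_pos_emb`, as used in the hunks above); the shapes, dtype, and variable names are illustrative, not part of this change:

    import torch
    from transformer_engine.pytorch.attention import (
        RotaryPositionEmbedding,
        apply_rotary_pos_emb,
    )

    seq_len, batch, heads, head_dim = 128, 2, 16, 64

    # Frequencies come out of RotaryPositionEmbedding in float32; as in the
    # test fix above, they are only moved to the device, not cast to the
    # activation dtype.
    rope = RotaryPositionEmbedding(dim=head_dim)
    freqs = rope(seq_len).to(device="cuda")

    # A query tensor in the "sbhd" layout (sequence, batch, heads, head_dim),
    # matching self.qkv_format in MultiheadAttention's default configuration.
    query = torch.randn(seq_len, batch, heads, head_dim,
                        dtype=torch.bfloat16, device="cuda")

    # fused=True dispatches to the fused CUDA RoPE kernel instead of the
    # unfused PyTorch implementation.
    query = apply_rotary_pos_emb(query, freqs, "sbhd", fused=True)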