Unverified commit 8d62d5c2, authored by Przemyslaw Tredak, committed by GitHub
Browse files

Use fused implementation of RoPE in MultiHeadAttention (#658)



* Use fused implementation of RoPE in MultiHeadAttention
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fix freqs dtype
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 1e780946
...@@ -811,7 +811,7 @@ def _run_transformer_layer( ...@@ -811,7 +811,7 @@ def _run_transformer_layer(
rotary_pos_emb = None rotary_pos_emb = None
if RoPE: if RoPE:
PE = RotaryPositionEmbedding(dim=config.head_dim) PE = RotaryPositionEmbedding(dim=config.head_dim)
rotary_pos_emb = PE(config.max_seqlen_q).to(dtype=dtype, device="cuda") rotary_pos_emb = PE(config.max_seqlen_q).to(device="cuda")
# Set up model # Set up model
block = ( block = (
......
...@@ -3625,8 +3625,8 @@ class MultiheadAttention(torch.nn.Module): ...@@ -3625,8 +3625,8 @@ class MultiheadAttention(torch.nn.Module):
# apply relative positional encoding (rotary embedding) # apply relative positional encoding (rotary embedding)
if rotary_pos_emb is not None: if rotary_pos_emb is not None:
q_pos_emb, k_pos_emb = rotary_pos_emb q_pos_emb, k_pos_emb = rotary_pos_emb
query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb, self.qkv_format) query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb, self.qkv_format, fused=True)
key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb, self.qkv_format) key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb, self.qkv_format, fused=True)
context_layer = self.core_attention( context_layer = self.core_attention(
query_layer, query_layer,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment