Unverified Commit 1bd45b97 authored by Peter St. John, committed by GitHub

Expose interleaved parameter for rotary position embeddings (#1783)
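
What "interleaved" means here: RoPE rotates the elements of each attention head in pairs. With `interleaved=True` the pairs are adjacent elements (x0, x1), (x2, x3), ...; with the default `False`, element i is paired with element i + d/2 from the other half of the head dimension. A minimal pure-PyTorch sketch of the two pairing conventions (the `rotate` helper below is illustrative only, not part of Transformer Engine):

```python
import torch

def rotate(x: torch.Tensor, interleaved: bool) -> torch.Tensor:
    """Return the tensor multiplied by sin(theta) in
    rope(x) = x * cos(theta) + rotate(x) * sin(theta)."""
    if interleaved:
        # Interleaved (GPT-J style): adjacent pairs (x0, x1), (x2, x3), ...
        even, odd = x[..., 0::2], x[..., 1::2]
        return torch.stack((-odd, even), dim=-1).flatten(start_dim=-2)
    # Default (GPT-NeoX style): element i is paired with element i + d/2.
    first, second = x.chunk(2, dim=-1)
    return torch.cat((-second, first), dim=-1)

# Note: the angle tensor must match the chosen layout -- each pair's angle
# appears at both of its positions (repeat_interleave(2) for the interleaved
# layout, a concatenated doubled half for the default).
```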


Signed-off-by: Peter St. John <pstjohn@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent d2973cb8
@@ -116,6 +116,8 @@ class MultiheadAttention(torch.nn.Module):
interpretation is that the individual `q`, `k`, and `v` weights for each
attention head are interleaved. This parameter is set to `False` when
using :attr:`fuse_qkv_params=False`.
+rotary_pos_interleaved : bool, default = `False`
+whether to use interleaved rotary position embeddings, i.e. rotate adjacent element pairs rather than split halves of each attention head.
bias : bool, default = `True`
if set to `False`, the transformer layer will not learn any additive biases.
device : Union[torch.device, str], default = "cuda"
@@ -201,6 +203,7 @@ class MultiheadAttention(torch.nn.Module):
fuse_qkv_params: bool = False,
zero_centered_gamma: bool = False,
qkv_weight_interleaved: bool = True,
+rotary_pos_interleaved: bool = False,
ub_overlap_ag: bool = False,
ub_overlap_rs: bool = False,
ub_overlap_rs_dgrad: bool = False,
@@ -239,6 +242,7 @@ class MultiheadAttention(torch.nn.Module):
if not fuse_qkv_params:
qkv_weight_interleaved = False
self.qkv_weight_interleaved = qkv_weight_interleaved
+self.rotary_pos_interleaved = rotary_pos_interleaved
assert attention_type in AttnTypes, f"attention_type {attention_type} not supported"
if layer_number is not None:
@@ -775,6 +779,7 @@ class MultiheadAttention(torch.nn.Module):
cu_seqlens=cu_seqlens_q,
cp_size=self.cp_size,
cp_rank=self.cp_rank,
+interleaved=self.rotary_pos_interleaved,
)
key_layer = apply_rotary_pos_emb(
key_layer,
@@ -784,6 +789,7 @@ class MultiheadAttention(torch.nn.Module):
cu_seqlens=cu_seqlens_kv,
cp_size=self.cp_size,
cp_rank=self.cp_rank,
+interleaved=self.rotary_pos_interleaved,
)
# ===========================
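
The flag is threaded into the two `apply_rotary_pos_emb` calls above. A minimal sketch of exercising it at that level, assuming `apply_rotary_pos_emb` and the `RotaryPositionEmbedding` frequency helper are importable from `transformer_engine.pytorch.attention` (import paths can vary between releases):

```python
import torch
from transformer_engine.pytorch.attention import (
    RotaryPositionEmbedding,
    apply_rotary_pos_emb,
)

seq, batch, heads, head_dim = 128, 2, 16, 64
rope = RotaryPositionEmbedding(head_dim)
freqs = rope(seq)  # rotation angles, shape [seq, 1, 1, head_dim]

q = torch.randn(seq, batch, heads, head_dim, device="cuda")
# Same input, the two pairing conventions exposed by this change:
q_interleaved = apply_rotary_pos_emb(q, freqs, interleaved=True)
q_half = apply_rotary_pos_emb(q, freqs, interleaved=False)
```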
@@ -168,6 +168,8 @@ class TransformerLayer(torch.nn.Module):
interpretation is that the individual `q`, `k`, and `v` weights for each
attention head are interleaved. This parameter is set to `False` when
using :attr:`fuse_qkv_params=False`.
+rotary_pos_interleaved : bool, default = `False`
+whether to use interleaved rotary position embeddings, i.e. rotate adjacent element pairs rather than split halves of each attention head.
bias : bool, default = `True`
if set to `False`, the transformer layer will not learn any additive biases.
activation : str, default = 'gelu'
@@ -267,6 +269,7 @@ class TransformerLayer(torch.nn.Module):
drop_path_rate: float = 0.0,
set_parallel_mode: bool = False,
fuse_qkv_params: bool = False,
+rotary_pos_interleaved: bool = False,
zero_centered_gamma: bool = False,
qkv_weight_interleaved: bool = True,
ub_tp_comm_overlap: bool = False,
@@ -363,6 +366,7 @@ class TransformerLayer(torch.nn.Module):
"fuse_qkv_params": fuse_qkv_params,
"zero_centered_gamma": zero_centered_gamma,
"qkv_weight_interleaved": qkv_weight_interleaved,
"rotary_pos_interleaved": rotary_pos_interleaved,
"ub_bulk_wgrad": ub_bulk_wgrad,
"ub_bulk_dgrad": ub_bulk_dgrad,
"ub_overlap_ag": ub_overlap_ag,
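
End to end, the new argument only needs to be passed at construction time; the kwargs dict above forwards it to `MultiheadAttention`. A usage sketch, assuming the standard `TransformerLayer` constructor and its `rotary_pos_emb` forward argument (modules default to `device="cuda"`):

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch.attention import RotaryPositionEmbedding

hidden, heads, seq, batch = 1024, 16, 128, 2
layer = te.TransformerLayer(
    hidden_size=hidden,
    ffn_hidden_size=4 * hidden,
    num_attention_heads=heads,
    rotary_pos_interleaved=True,  # the flag exposed by this change
)

freqs = RotaryPositionEmbedding(hidden // heads)(seq)
x = torch.randn(seq, batch, hidden, device="cuda")
y = layer(x, rotary_pos_emb=freqs)
```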