Unverified commit 1bd45b97 authored by Peter St. John, committed by GitHub
Browse files

Expose interleaved parameter for rotary position embeddings (#1783)


Signed-off-by: Peter St. John <pstjohn@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent d2973cb8
...@@ -116,6 +116,8 @@ class MultiheadAttention(torch.nn.Module): ...@@ -116,6 +116,8 @@ class MultiheadAttention(torch.nn.Module):
interpretation is that the individual `q`, `k`, and `v` weights for each interpretation is that the individual `q`, `k`, and `v` weights for each
attention head are interleaved. This parameter is set to `False` when attention head are interleaved. This parameter is set to `False` when
using :attr:`fuse_qkv_params=False`. using :attr:`fuse_qkv_params=False`.
rotary_pos_interleaved : bool, default = `False`
whether to use interleaved rotary position embeddings.
bias : bool, default = `True` bias : bool, default = `True`
if set to `False`, the transformer layer will not learn any additive biases. if set to `False`, the transformer layer will not learn any additive biases.
device : Union[torch.device, str], default = "cuda" device : Union[torch.device, str], default = "cuda"
...@@ -201,6 +203,7 @@ class MultiheadAttention(torch.nn.Module): ...@@ -201,6 +203,7 @@ class MultiheadAttention(torch.nn.Module):
fuse_qkv_params: bool = False, fuse_qkv_params: bool = False,
zero_centered_gamma: bool = False, zero_centered_gamma: bool = False,
qkv_weight_interleaved: bool = True, qkv_weight_interleaved: bool = True,
rotary_pos_interleaved: bool = False,
ub_overlap_ag: bool = False, ub_overlap_ag: bool = False,
ub_overlap_rs: bool = False, ub_overlap_rs: bool = False,
ub_overlap_rs_dgrad: bool = False, ub_overlap_rs_dgrad: bool = False,
...@@ -239,6 +242,7 @@ class MultiheadAttention(torch.nn.Module): ...@@ -239,6 +242,7 @@ class MultiheadAttention(torch.nn.Module):
if not fuse_qkv_params: if not fuse_qkv_params:
qkv_weight_interleaved = False qkv_weight_interleaved = False
self.qkv_weight_interleaved = qkv_weight_interleaved self.qkv_weight_interleaved = qkv_weight_interleaved
self.rotary_pos_interleaved = rotary_pos_interleaved
assert attention_type in AttnTypes, f"attention_type {attention_type} not supported" assert attention_type in AttnTypes, f"attention_type {attention_type} not supported"
if layer_number is not None: if layer_number is not None:
...@@ -775,6 +779,7 @@ class MultiheadAttention(torch.nn.Module): ...@@ -775,6 +779,7 @@ class MultiheadAttention(torch.nn.Module):
cu_seqlens=cu_seqlens_q, cu_seqlens=cu_seqlens_q,
cp_size=self.cp_size, cp_size=self.cp_size,
cp_rank=self.cp_rank, cp_rank=self.cp_rank,
interleaved=self.rotary_pos_interleaved,
) )
key_layer = apply_rotary_pos_emb( key_layer = apply_rotary_pos_emb(
key_layer, key_layer,
...@@ -784,6 +789,7 @@ class MultiheadAttention(torch.nn.Module): ...@@ -784,6 +789,7 @@ class MultiheadAttention(torch.nn.Module):
cu_seqlens=cu_seqlens_kv, cu_seqlens=cu_seqlens_kv,
cp_size=self.cp_size, cp_size=self.cp_size,
cp_rank=self.cp_rank, cp_rank=self.cp_rank,
interleaved=self.rotary_pos_interleaved,
) )
# =========================== # ===========================
......
...@@ -168,6 +168,8 @@ class TransformerLayer(torch.nn.Module): ...@@ -168,6 +168,8 @@ class TransformerLayer(torch.nn.Module):
interpretation is that the individual `q`, `k`, and `v` weights for each interpretation is that the individual `q`, `k`, and `v` weights for each
attention head are interleaved. This parameter is set to `False` when attention head are interleaved. This parameter is set to `False` when
using :attr:`fuse_qkv_params=False`. using :attr:`fuse_qkv_params=False`.
rotary_pos_interleaved : bool, default = `False`
whether to use interleaved rotary position embeddings.
bias : bool, default = `True` bias : bool, default = `True`
if set to `False`, the transformer layer will not learn any additive biases. if set to `False`, the transformer layer will not learn any additive biases.
activation : str, default = 'gelu' activation : str, default = 'gelu'
...@@ -267,6 +269,7 @@ class TransformerLayer(torch.nn.Module): ...@@ -267,6 +269,7 @@ class TransformerLayer(torch.nn.Module):
drop_path_rate: float = 0.0, drop_path_rate: float = 0.0,
set_parallel_mode: bool = False, set_parallel_mode: bool = False,
fuse_qkv_params: bool = False, fuse_qkv_params: bool = False,
rotary_pos_interleaved: bool = False,
zero_centered_gamma: bool = False, zero_centered_gamma: bool = False,
qkv_weight_interleaved: bool = True, qkv_weight_interleaved: bool = True,
ub_tp_comm_overlap: bool = False, ub_tp_comm_overlap: bool = False,
...@@ -363,6 +366,7 @@ class TransformerLayer(torch.nn.Module): ...@@ -363,6 +366,7 @@ class TransformerLayer(torch.nn.Module):
"fuse_qkv_params": fuse_qkv_params, "fuse_qkv_params": fuse_qkv_params,
"zero_centered_gamma": zero_centered_gamma, "zero_centered_gamma": zero_centered_gamma,
"qkv_weight_interleaved": qkv_weight_interleaved, "qkv_weight_interleaved": qkv_weight_interleaved,
"rotary_pos_interleaved": rotary_pos_interleaved,
"ub_bulk_wgrad": ub_bulk_wgrad, "ub_bulk_wgrad": ub_bulk_wgrad,
"ub_bulk_dgrad": ub_bulk_dgrad, "ub_bulk_dgrad": ub_bulk_dgrad,
"ub_overlap_ag": ub_overlap_ag, "ub_overlap_ag": ub_overlap_ag,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment