Unverified commit b8d453ef authored by Charlene Yang, committed by GitHub

[PyTorch] Merge `k_channels` and `v_channels` back to `kv_channels` (#1094)



* merge k_channels and v_channels back to kv_channels and accept a tuple
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix isinstance call
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix MLA tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent ec49a52b
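
The change in a nutshell: `DotProductAttention` now takes a single `kv_channels` argument that is either an int (same head size for keys and values, as before) or a `(k, v)` tuple (different head sizes, e.g. for MLA). A minimal usage sketch, with hypothetical head sizes that are not taken from this commit:

    import transformer_engine.pytorch as te

    # Same head size for K and V: kv_channels stays an int, exactly as before.
    attn_same = te.DotProductAttention(
        num_attention_heads=16,
        kv_channels=64,
    )

    # Different K and V head sizes (e.g. MLA-style): pass a (k, v) tuple.
    attn_split = te.DotProductAttention(
        num_attention_heads=16,
        kv_channels=(192, 128),
    )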
@@ -902,7 +902,7 @@ def _run_dot_product_attention(
     # Set up model
     block = DotProductAttention(
         config.num_heads,
-        config.head_dim_qk,
+        (config.head_dim_qk, config.head_dim_v),
         num_gqa_groups=config.num_gqa_groups,
         attention_dropout=config.dropout_p,
         qkv_format=qkv_format,
@@ -1083,7 +1083,7 @@ def test_export_core_attention(
     model = te.attention.DotProductAttention(
         num_attention_heads=num_attention_heads,
-        k_channels=kv_channels,
+        kv_channels=kv_channels,
         attention_dropout=0.5,
         qkv_format=qkv_format,
         attn_mask_type=attn_mask_type,
@@ -5177,10 +5177,9 @@ class DotProductAttention(TransformerEngineBaseModule):
     ----------
     num_attention_heads : int
                          number of attention heads in the transformer layer.
-    k_channels : int
-                number of channels per attention head in key.
-    v_channels : Optional[int] = None
-                number of channels per attention head in value.
+    kv_channels : Union[int, Tuple[int, int]]
+                the head size in key and value tensors. If the same, :attr:`kv_channels` can be
+                an integer; if not, :attr:`kv_channels` should be a tuple of two integers.
     num_gqa_groups : Optional[int] = None
                     number of GQA groups in the transformer layer.
                     Grouped Query Attention is described in
@@ -5242,7 +5241,7 @@ class DotProductAttention(TransformerEngineBaseModule):
                 For that, please use `get_qkv_layout` to gain the layout information.
     softmax_scale: Optional[float], default = `None`
                 softmax scale for the attention scores. If `None`, defaults to
-                `1.0 / math.sqrt(kv_channels)`.
+                `1.0/math.sqrt(kv_channels if isinstance(kv_channels, int) else kv_channels[0])`.
     Parallelism parameters
     ----------------------
...@@ -5266,8 +5265,7 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -5266,8 +5265,7 @@ class DotProductAttention(TransformerEngineBaseModule):
def __init__( def __init__(
self, self,
num_attention_heads: int, num_attention_heads: int,
k_channels: int, kv_channels: Union[int, Tuple[int, int]],
v_channels: Optional[int] = None,
num_gqa_groups: Optional[int] = None, num_gqa_groups: Optional[int] = None,
attention_dropout: float = 0.0, attention_dropout: float = 0.0,
qkv_format: str = "sbhd", qkv_format: str = "sbhd",
@@ -5310,8 +5308,12 @@ class DotProductAttention(TransformerEngineBaseModule):
         self.cp_global_ranks = cp_global_ranks
         self.cp_stream = cp_stream
-        self.hidden_size_per_attention_head = k_channels
-        self.v_channels = k_channels if v_channels is None else v_channels
+        self.hidden_size_per_attention_head_k = (
+            kv_channels if isinstance(kv_channels, int) else kv_channels[0]
+        )
+        self.hidden_size_per_attention_head_v = (
+            kv_channels if isinstance(kv_channels, int) else kv_channels[1]
+        )
         self.num_gqa_groups = num_attention_heads if num_gqa_groups is None else num_gqa_groups
         self.num_gqa_groups_per_partition = int(self.num_gqa_groups // self.tp_size)
...@@ -5329,7 +5331,9 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -5329,7 +5331,9 @@ class DotProductAttention(TransformerEngineBaseModule):
attention_dropout_ctx = self.rng_states_tracker.fork attention_dropout_ctx = self.rng_states_tracker.fork
if softmax_scale is None: if softmax_scale is None:
softmax_scale = 1.0 / math.sqrt(k_channels) softmax_scale = 1.0 / math.sqrt(
kv_channels if isinstance(kv_channels, int) else kv_channels[0]
)
self.deterministic = ( self.deterministic = (
not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1"))) not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))
@@ -5628,6 +5632,14 @@ class DotProductAttention(TransformerEngineBaseModule):
         assert (
             key_layer.shape[:-1] == value_layer.shape[:-1]
         ), "Keys and values must have the same batch size, sequence length and number of heads!"
+        assert (
+            key_layer.shape[-1] == self.hidden_size_per_attention_head_k
+        ), f"Keys have head_dim = {key_layer.shape[-1]}, "
+        "but expected head_dim = {self.hidden_size_per_attention_head_k}!"
+        assert (
+            value_layer.shape[-1] == self.hidden_size_per_attention_head_v
+        ), f"Values have head_dim = {value_layer.shape[-1]}, "
+        "but expected head_dim = {self.hidden_size_per_attention_head_v}!"
         if attn_mask_type is None:
             attn_mask_type = self.attn_mask_type
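
Finally, a hedged sketch of how the new head-dim checks surface at call time (shapes, dtype/device, and the q/k/v call pattern below are assumptions for illustration, not taken from this commit):

    import torch
    import transformer_engine.pytorch as te

    # "sbhd" layout: [seq, batch, heads, head_dim]; K head_dim 192, V head_dim 128.
    attn = te.DotProductAttention(num_attention_heads=16, kv_channels=(192, 128), qkv_format="sbhd")
    q = torch.randn(128, 2, 16, 192, dtype=torch.bfloat16, device="cuda")
    k = torch.randn(128, 2, 16, 192, dtype=torch.bfloat16, device="cuda")
    v = torch.randn(128, 2, 16, 128, dtype=torch.bfloat16, device="cuda")
    out = attn(q, k, v)  # head dims match hidden_size_per_attention_head_k / _v

    v_bad = torch.randn(128, 2, 16, 96, dtype=torch.bfloat16, device="cuda")
    # attn(q, k, v_bad) would now trip the new assertion on the value head_dim.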