"tests/pytorch/test_quantized_tensor.py" did not exist on "d99142a0177a2462cbda07a31aaa8e68b4e85461"
Unverified commit b8d453ef, authored by Charlene Yang, committed by GitHub
Browse files

[PyTorch] Merge `k_channels` and `v_channels` back to `kv_channels` (#1094)



* merge k_channels and v_channels back to kv_channels and accept a tuple
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix isinstance call
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix MLA tests
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent ec49a52b
......@@ -902,7 +902,7 @@ def _run_dot_product_attention(
# Set up model
block = DotProductAttention(
config.num_heads,
config.head_dim_qk,
(config.head_dim_qk, config.head_dim_v),
num_gqa_groups=config.num_gqa_groups,
attention_dropout=config.dropout_p,
qkv_format=qkv_format,
......
......@@ -1083,7 +1083,7 @@ def test_export_core_attention(
model = te.attention.DotProductAttention(
num_attention_heads=num_attention_heads,
k_channels=kv_channels,
kv_channels=kv_channels,
attention_dropout=0.5,
qkv_format=qkv_format,
attn_mask_type=attn_mask_type,
......
......@@ -5177,10 +5177,9 @@ class DotProductAttention(TransformerEngineBaseModule):
----------
num_attention_heads : int
number of attention heads in the transformer layer.
k_channels : int
number of channels per attention head in key.
v_channels : Optional[int] = None
number of channels per attention head in value.
kv_channels : Union[int, Tuple[int, int]]
the head size in key and value tensors. If the same, :attr:`kv_channels` can be
an integer; if not, :attr:`kv_channels` should be a tuple of two integers.
num_gqa_groups : Optional[int] = None
number of GQA groups in the transformer layer.
Grouped Query Attention is described in
......@@ -5242,7 +5241,7 @@ class DotProductAttention(TransformerEngineBaseModule):
For that, please use `get_qkv_layout` to gain the layout information.
softmax_scale: Optional[float], default = `None`
softmax scale for the attention scores. If `None`, defaults to
`1.0 / math.sqrt(kv_channels)`.
`1.0/math.sqrt(kv_channels if isinstance(kv_channels, int) else kv_channels[0])`.
Parallelism parameters
----------------------
......@@ -5266,8 +5265,7 @@ class DotProductAttention(TransformerEngineBaseModule):
def __init__(
self,
num_attention_heads: int,
k_channels: int,
v_channels: Optional[int] = None,
kv_channels: Union[int, Tuple[int, int]],
num_gqa_groups: Optional[int] = None,
attention_dropout: float = 0.0,
qkv_format: str = "sbhd",
......@@ -5310,8 +5308,12 @@ class DotProductAttention(TransformerEngineBaseModule):
self.cp_global_ranks = cp_global_ranks
self.cp_stream = cp_stream
self.hidden_size_per_attention_head = k_channels
self.v_channels = k_channels if v_channels is None else v_channels
self.hidden_size_per_attention_head_k = (
kv_channels if isinstance(kv_channels, int) else kv_channels[0]
)
self.hidden_size_per_attention_head_v = (
kv_channels if isinstance(kv_channels, int) else kv_channels[1]
)
self.num_gqa_groups = num_attention_heads if num_gqa_groups is None else num_gqa_groups
self.num_gqa_groups_per_partition = int(self.num_gqa_groups // self.tp_size)
......@@ -5329,7 +5331,9 @@ class DotProductAttention(TransformerEngineBaseModule):
attention_dropout_ctx = self.rng_states_tracker.fork
if softmax_scale is None:
softmax_scale = 1.0 / math.sqrt(k_channels)
softmax_scale = 1.0 / math.sqrt(
kv_channels if isinstance(kv_channels, int) else kv_channels[0]
)
self.deterministic = (
not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))
......@@ -5628,6 +5632,14 @@ class DotProductAttention(TransformerEngineBaseModule):
assert (
key_layer.shape[:-1] == value_layer.shape[:-1]
), "Keys and values must have the same batch size, sequence length and number of heads!"
# NOTE(review): in the original, the second string literal of each assert
# message was a separate no-op statement on its own line (and lacked the `f`
# prefix), so the reported message was truncated to its first half and the
# `{self....}` placeholder would never have been interpolated anyway.
# Parenthesize the message so implicit concatenation applies, and make both
# pieces f-strings so the actual expected head_dim is reported.
assert (
    key_layer.shape[-1] == self.hidden_size_per_attention_head_k
), (
    f"Keys have head_dim = {key_layer.shape[-1]}, "
    f"but expected head_dim = {self.hidden_size_per_attention_head_k}!"
)
assert (
    value_layer.shape[-1] == self.hidden_size_per_attention_head_v
), (
    f"Values have head_dim = {value_layer.shape[-1]}, "
    f"but expected head_dim = {self.hidden_size_per_attention_head_v}!"
)
if attn_mask_type is None:
attn_mask_type = self.attn_mask_type
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment