Commit fd60eedd authored by wenjh

Support GLM params

Add GLM-shaped attention test configs (head_dim 256), plumb an optional return_qk_max flag through FlashAttention and DotProductAttention, and keep FlashAttention 2 enabled for head_dim_qk = head_dim_v = 256 on ROCm.

Signed-off-by: wenjh <wenjh@sugon.com>
parent 99a8a0c5
@@ -83,6 +83,8 @@ model_configs_base = {
"base_5_1": ModelConfig(8, 128, 16, 512, max_seqlen_kv=2048),
"base_6_0": ModelConfig(8, 1, 16, 1024, max_seqlen_kv=2048),
"base_6_1": ModelConfig(8, 128, 16, 1024, max_seqlen_kv=2048),
"base_7_0": ModelConfig(4, 1226, 32, 256),
}
@@ -277,6 +279,8 @@ model_configs_mla = {
# "mla_3_2": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048, head_dim_v=128), # inference
# "mla_3_3": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=128), # inference
"mla_3_4": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=160), # inference
"mla_4_0": ModelConfig(4, 1226, 32, 256),
}
@@ -332,6 +336,8 @@ model_configs_mask = {
"mask_10_1": ModelConfig(
2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="padding_causal_bottom_right"
),
"mask_11_0": ModelConfig(4, 1226, 32, 256, attn_mask_type="padding_causal"),
}
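The three new entries above share a single GLM-style shape, ModelConfig(4, 1226, 32, 256). Judging from neighbouring entries such as mla_3_4, where the fourth positional argument lines up with head_dim_v, the positional arguments appear to map as follows; the names below are an assumption for illustration, not taken from the ModelConfig definition itself:

# Assumed meaning of ModelConfig(4, 1226, 32, 256):
#   4    -> batch size
#   1226 -> sequence length (q and kv, since max_seqlen_kv is not overridden)
#   32   -> number of attention heads
#   256  -> head dimension, the value this commit enables for FlashAttention 2 on ROCm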
@@ -446,6 +446,7 @@ class FlashAttention(torch.nn.Module):
attention_type: str = "self",
layer_number: Optional[int] = None,
deterministic: bool = False,
return_qk_max: Optional[bool] = False,
) -> None:
super().__init__()
@@ -470,6 +471,8 @@ class FlashAttention(torch.nn.Module):
if not self.logger.hasHandlers():
self.logger.addHandler(attn_log._stream_handler)
self.return_qk_max = return_qk_max
@classmethod
def _get_cached_page_offsets(
cls, split_factor: int, device: torch.device, dtype: torch.dtype
@@ -724,6 +727,7 @@ class FlashAttention(torch.nn.Module):
alibi_slopes is None
), "Alibi slope bias addition is not supported with context parallelism."
with self.attention_dropout_ctx():
assert (not self.return_qk_max), "attn_forward_func_with_cp does not support returning qk_max yet."
output = attn_forward_func_with_cp(
self.training,
query_layer,
@@ -821,16 +825,29 @@ class FlashAttention(torch.nn.Module):
allow_negative_entries=False,
)
fa_optional_forward_kwargs["block_table"] = remapped_block_table
output = func(
query_layer,
key_layer,
value_layer,
*fa_optional_forward_args_thd,
self.attention_dropout if self.training else 0.0,
softmax_scale=self.softmax_scale,
causal="causal" in attn_mask_type,
**fa_optional_forward_kwargs,
)
if not self.return_qk_max:
output = func(
query_layer,
key_layer,
value_layer,
*fa_optional_forward_args_thd,
self.attention_dropout if self.training else 0.0,
softmax_scale=self.softmax_scale,
causal="causal" in attn_mask_type,
**fa_optional_forward_kwargs,
)
else:
output, qk_max = func(
query_layer,
key_layer,
value_layer,
*fa_optional_forward_args_thd,
self.attention_dropout if self.training else 0.0,
softmax_scale=self.softmax_scale,
causal="causal" in attn_mask_type,
return_qkmax=True,
**fa_optional_forward_kwargs,
)
else:
fa_3_optional_forward_kwargs = {}
fa_3_optional_forward_kwargs["window_size"] = window_size
@@ -886,15 +903,27 @@ class FlashAttention(torch.nn.Module):
for x in [query_layer, key_layer, value_layer]
)
try:
output = func(
query_layer,
key_layer,
value_layer,
*fa_optional_forward_args_thd,
softmax_scale=self.softmax_scale,
causal="causal" in attn_mask_type,
**fa_3_optional_forward_kwargs,
)
if not self.return_qk_max:
output = func(
query_layer,
key_layer,
value_layer,
*fa_optional_forward_args_thd,
softmax_scale=self.softmax_scale,
causal="causal" in attn_mask_type,
**fa_3_optional_forward_kwargs,
)
else:
output, qk_max = func(
query_layer,
key_layer,
value_layer,
*fa_optional_forward_args_thd,
softmax_scale=self.softmax_scale,
causal="causal" in attn_mask_type,
return_qkmax=True,
**fa_3_optional_forward_kwargs,
)
if isinstance(output, (List, Tuple)):
output = output[0]
except TypeError as e:
@@ -956,6 +985,10 @@ class FlashAttention(torch.nn.Module):
elif q_format == "thd":
# thd -> t(hd)
output = output.reshape(output.shape[0], -1)
if self.return_qk_max:
return output.contiguous(), qk_max
return output.contiguous()
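With the flag set, FlashAttention.forward now returns a (context, qk_max) tuple instead of a bare tensor; the context-parallel path above explicitly rejects the flag, and only the flash-attn 2/3 varlen calls in this diff produce qk_max. A minimal sketch of how a caller has to unpack the result (everything other than return_qk_max is illustrative):

# Sketch only: callers must branch on the flag, since the return type changes.
result = flash_attention_module(query_layer, key_layer, value_layer)
if flash_attention_module.return_qk_max:
    context, qk_max = result  # qk_max is passed through from the kernel unmodified
else:
    context = result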
@@ -223,6 +223,7 @@ class DotProductAttention(TransformerEngineBaseModule):
cp_stream: torch.cuda.Stream = None,
cp_comm_type: str = "p2p",
softmax_scale: Optional[float] = None,
return_qk_max: Optional[bool] = False,
) -> None:
super().__init__()
@@ -251,6 +252,8 @@ class DotProductAttention(TransformerEngineBaseModule):
self.cp_stream = cp_stream
self.cp_comm_type = cp_comm_type
self.return_qk_max = return_qk_max
self.hidden_size_per_attention_head_k = (
kv_channels if isinstance(kv_channels, int) else kv_channels[0]
)
@@ -317,6 +320,7 @@ class DotProductAttention(TransformerEngineBaseModule):
attention_type=attention_type,
layer_number=layer_number,
deterministic=self.deterministic,
return_qk_max=self.return_qk_max,
**attn_kwargs,
)
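A minimal usage sketch at the public DotProductAttention level, assuming the FlashAttention backend is selected and the installed flash-attention kernel accepts the return_qkmax keyword used above (shapes follow the new GLM-style test configs; everything except the return_qk_max argument is illustrative):

import torch
import transformer_engine.pytorch as te

# 32 heads with head_dim 256, as in the new test configs.
attn = te.DotProductAttention(
    num_attention_heads=32,
    kv_channels=256,
    attn_mask_type="causal",
    return_qk_max=True,  # new flag added by this commit
)

# Default qkv_format is "sbhd": [seqlen, batch, heads, head_dim].
q = torch.randn(1226, 4, 32, 256, dtype=torch.bfloat16, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)

# With the flag set, the flash path returns the context tensor plus qk_max.
out, qk_max = attn(q, k, v)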
@@ -507,17 +507,18 @@ def get_attention_backend(
and device_compute_capability not in ((8, 0), (9, 0), (10, 0), (12, 0))
)
):
if FlashAttentionUtils.is_installed:
logger.debug(
"Disabling FlashAttention 2 due to unsupported head_dim_qk and head_dim_v. "
"Supported: head_dim_qk = head_dim_v, head_dim_qk %%8 = 0, "
"head_dim_qk <= 256 (>192 requires sm80/90/100+). "
"Found: head_dim_qk = %s, head_dim_v = %s, on sm%s.",
head_dim_qk,
head_dim_v,
".".join([str(i) for i in device_compute_capability]),
)
use_flash_attention_2 = False
if not (IS_HIP_EXTENSION and head_dim_qk == 256 and head_dim_v == 256):
if FlashAttentionUtils.is_installed:
logger.debug(
"Disabling FlashAttention 2 due to unsupported head_dim_qk and head_dim_v. "
"Supported: head_dim_qk = head_dim_v, head_dim_qk %%8 = 0, "
"head_dim_qk <= 256 (>192 requires sm80/90/100+). "
"Found: head_dim_qk = %s, head_dim_v = %s, on sm%s.",
head_dim_qk,
head_dim_v,
".".join([str(i) for i in device_compute_capability]),
)
use_flash_attention_2 = False
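Inside the unsupported-head-dim branch above, the new guard behaves like the standalone predicate below (a sketch; IS_HIP_EXTENSION, head_dim_qk, head_dim_v and use_flash_attention_2 are the variables already in scope in this function):

# FlashAttention 2 now stays enabled on the ROCm build when both head dims are exactly 256;
# every other previously-unsupported head-dim combination is still rejected as before.
keep_fa2_for_glm = IS_HIP_EXTENSION and head_dim_qk == 256 and head_dim_v == 256
if not keep_fa2_for_glm:
    use_flash_attention_2 = False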
if use_flash_attention_3 and (head_dim_qk > 128 or head_dim_v > 128):
if FlashAttentionUtils.v3_is_installed:
logger.debug("Disabling FlashAttention 3 for head_dim > 128")