"...git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on "9f9b48168f106a172b28aeb44cb12f2c2c232181"
Commit fd60eedd authored by wenjh
Browse files

Support GLM params


Signed-off-by: wenjh <wenjh@sugon.com>
parent 99a8a0c5
Pipeline #3435 failed with stages
in 0 seconds
...@@ -83,6 +83,8 @@ model_configs_base = { ...@@ -83,6 +83,8 @@ model_configs_base = {
"base_5_1": ModelConfig(8, 128, 16, 512, max_seqlen_kv=2048), "base_5_1": ModelConfig(8, 128, 16, 512, max_seqlen_kv=2048),
"base_6_0": ModelConfig(8, 1, 16, 1024, max_seqlen_kv=2048), "base_6_0": ModelConfig(8, 1, 16, 1024, max_seqlen_kv=2048),
"base_6_1": ModelConfig(8, 128, 16, 1024, max_seqlen_kv=2048), "base_6_1": ModelConfig(8, 128, 16, 1024, max_seqlen_kv=2048),
"base_7_0": ModelConfig(4, 1226, 32, 256),
} }
...@@ -277,6 +279,8 @@ model_configs_mla = { ...@@ -277,6 +279,8 @@ model_configs_mla = {
# "mla_3_2": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048, head_dim_v=128), # inference # "mla_3_2": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048, head_dim_v=128), # inference
# "mla_3_3": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=128), # inference # "mla_3_3": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=128), # inference
"mla_3_4": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=160), # inference "mla_3_4": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=160), # inference
"mla_4_0": ModelConfig(4, 1226, 32, 256),
} }
...@@ -332,6 +336,8 @@ model_configs_mask = { ...@@ -332,6 +336,8 @@ model_configs_mask = {
"mask_10_1": ModelConfig( "mask_10_1": ModelConfig(
2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="padding_causal_bottom_right" 2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="padding_causal_bottom_right"
), ),
"mask_11_0": ModelConfig(4, 1226, 32, 256, attn_mask_type="padding_causal"),
} }
......
...@@ -446,6 +446,7 @@ class FlashAttention(torch.nn.Module): ...@@ -446,6 +446,7 @@ class FlashAttention(torch.nn.Module):
attention_type: str = "self", attention_type: str = "self",
layer_number: Optional[int] = None, layer_number: Optional[int] = None,
deterministic: bool = False, deterministic: bool = False,
return_qk_max: Optional[bool] = False,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -470,6 +471,8 @@ class FlashAttention(torch.nn.Module): ...@@ -470,6 +471,8 @@ class FlashAttention(torch.nn.Module):
if not self.logger.hasHandlers(): if not self.logger.hasHandlers():
self.logger.addHandler(attn_log._stream_handler) self.logger.addHandler(attn_log._stream_handler)
self.return_qk_max = return_qk_max
@classmethod @classmethod
def _get_cached_page_offsets( def _get_cached_page_offsets(
cls, split_factor: int, device: torch.device, dtype: torch.dtype cls, split_factor: int, device: torch.device, dtype: torch.dtype
...@@ -724,6 +727,7 @@ class FlashAttention(torch.nn.Module): ...@@ -724,6 +727,7 @@ class FlashAttention(torch.nn.Module):
alibi_slopes is None alibi_slopes is None
), "Alibi slope bias addition is not supported with context parallelism." ), "Alibi slope bias addition is not supported with context parallelism."
with self.attention_dropout_ctx(): with self.attention_dropout_ctx():
assert (not self.return_qk_max), "attn_forward_func_with_cp does not support returning qk_max yet."
output = attn_forward_func_with_cp( output = attn_forward_func_with_cp(
self.training, self.training,
query_layer, query_layer,
...@@ -821,16 +825,29 @@ class FlashAttention(torch.nn.Module): ...@@ -821,16 +825,29 @@ class FlashAttention(torch.nn.Module):
allow_negative_entries=False, allow_negative_entries=False,
) )
fa_optional_forward_kwargs["block_table"] = remapped_block_table fa_optional_forward_kwargs["block_table"] = remapped_block_table
output = func( if not self.return_qk_max:
query_layer, output = func(
key_layer, query_layer,
value_layer, key_layer,
*fa_optional_forward_args_thd, value_layer,
self.attention_dropout if self.training else 0.0, *fa_optional_forward_args_thd,
softmax_scale=self.softmax_scale, self.attention_dropout if self.training else 0.0,
causal="causal" in attn_mask_type, softmax_scale=self.softmax_scale,
**fa_optional_forward_kwargs, causal="causal" in attn_mask_type,
) **fa_optional_forward_kwargs,
)
else:
output, qk_max = func(
query_layer,
key_layer,
value_layer,
*fa_optional_forward_args_thd,
self.attention_dropout if self.training else 0.0,
softmax_scale=self.softmax_scale,
causal="causal" in attn_mask_type,
return_qkmax=True,
**fa_optional_forward_kwargs,
)
else: else:
fa_3_optional_forward_kwargs = {} fa_3_optional_forward_kwargs = {}
fa_3_optional_forward_kwargs["window_size"] = window_size fa_3_optional_forward_kwargs["window_size"] = window_size
...@@ -886,15 +903,27 @@ class FlashAttention(torch.nn.Module): ...@@ -886,15 +903,27 @@ class FlashAttention(torch.nn.Module):
for x in [query_layer, key_layer, value_layer] for x in [query_layer, key_layer, value_layer]
) )
try: try:
output = func( if not self.return_qk_max:
query_layer, output = func(
key_layer, query_layer,
value_layer, key_layer,
*fa_optional_forward_args_thd, value_layer,
softmax_scale=self.softmax_scale, *fa_optional_forward_args_thd,
causal="causal" in attn_mask_type, softmax_scale=self.softmax_scale,
**fa_3_optional_forward_kwargs, causal="causal" in attn_mask_type,
) **fa_3_optional_forward_kwargs,
)
else:
output, qk_max = func(
query_layer,
key_layer,
value_layer,
*fa_optional_forward_args_thd,
softmax_scale=self.softmax_scale,
causal="causal" in attn_mask_type,
return_qkmax=True,
**fa_3_optional_forward_kwargs,
)
if isinstance(output, (List, Tuple)): if isinstance(output, (List, Tuple)):
output = output[0] output = output[0]
except TypeError as e: except TypeError as e:
...@@ -956,6 +985,10 @@ class FlashAttention(torch.nn.Module): ...@@ -956,6 +985,10 @@ class FlashAttention(torch.nn.Module):
elif q_format == "thd": elif q_format == "thd":
# thd -> t(hd) # thd -> t(hd)
output = output.reshape(output.shape[0], -1) output = output.reshape(output.shape[0], -1)
if self.return_qk_max:
return output.contiguous(), qk_max
return output.contiguous() return output.contiguous()
......
...@@ -223,6 +223,7 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -223,6 +223,7 @@ class DotProductAttention(TransformerEngineBaseModule):
cp_stream: torch.cuda.Stream = None, cp_stream: torch.cuda.Stream = None,
cp_comm_type: str = "p2p", cp_comm_type: str = "p2p",
softmax_scale: Optional[float] = None, softmax_scale: Optional[float] = None,
return_qk_max: Optional[bool] = False,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -251,6 +252,8 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -251,6 +252,8 @@ class DotProductAttention(TransformerEngineBaseModule):
self.cp_stream = cp_stream self.cp_stream = cp_stream
self.cp_comm_type = cp_comm_type self.cp_comm_type = cp_comm_type
self.return_qk_max = return_qk_max
self.hidden_size_per_attention_head_k = ( self.hidden_size_per_attention_head_k = (
kv_channels if isinstance(kv_channels, int) else kv_channels[0] kv_channels if isinstance(kv_channels, int) else kv_channels[0]
) )
...@@ -317,6 +320,7 @@ class DotProductAttention(TransformerEngineBaseModule): ...@@ -317,6 +320,7 @@ class DotProductAttention(TransformerEngineBaseModule):
attention_type=attention_type, attention_type=attention_type,
layer_number=layer_number, layer_number=layer_number,
deterministic=self.deterministic, deterministic=self.deterministic,
return_qk_max=self.return_qk_max,
**attn_kwargs, **attn_kwargs,
) )
......
...@@ -507,17 +507,18 @@ def get_attention_backend( ...@@ -507,17 +507,18 @@ def get_attention_backend(
and device_compute_capability not in ((8, 0), (9, 0), (10, 0), (12, 0)) and device_compute_capability not in ((8, 0), (9, 0), (10, 0), (12, 0))
) )
): ):
if FlashAttentionUtils.is_installed: if not (IS_HIP_EXTENSION and head_dim_qk == 256 and head_dim_v == 256):
logger.debug( if FlashAttentionUtils.is_installed:
"Disabling FlashAttention 2 due to unsupported head_dim_qk and head_dim_v. " logger.debug(
"Supported: head_dim_qk = head_dim_v, head_dim_qk %%8 = 0, " "Disabling FlashAttention 2 due to unsupported head_dim_qk and head_dim_v. "
"head_dim_qk <= 256 (>192 requires sm80/90/100+). " "Supported: head_dim_qk = head_dim_v, head_dim_qk %%8 = 0, "
"Found: head_dim_qk = %s, head_dim_v = %s, on sm%s.", "head_dim_qk <= 256 (>192 requires sm80/90/100+). "
head_dim_qk, "Found: head_dim_qk = %s, head_dim_v = %s, on sm%s.",
head_dim_v, head_dim_qk,
".".join([str(i) for i in device_compute_capability]), head_dim_v,
) ".".join([str(i) for i in device_compute_capability]),
use_flash_attention_2 = False )
use_flash_attention_2 = False
if use_flash_attention_3 and (head_dim_qk > 128 or head_dim_v > 128): if use_flash_attention_3 and (head_dim_qk > 128 or head_dim_v > 128):
if FlashAttentionUtils.v3_is_installed: if FlashAttentionUtils.v3_is_installed:
logger.debug("Disabling FlashAttention 3 for head_dim > 128") logger.debug("Disabling FlashAttention 3 for head_dim > 128")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment