Commit e2df3544 authored by zhuwenwen's avatar zhuwenwen
Browse files

update triton fa

parent 21c06ecb
...@@ -213,7 +213,7 @@ def _attn_fwd_inner( ...@@ -213,7 +213,7 @@ def _attn_fwd_inner(
{ {
"BLOCK_M": 256, "BLOCK_M": 256,
"BLOCK_N": 64, "BLOCK_N": 64,
"waves_per_eu": 2, "waves_per_eu": 0,
"PRE_LOAD_V": False, "PRE_LOAD_V": False,
}, },
num_stages=1, num_stages=1,
...@@ -223,7 +223,7 @@ def _attn_fwd_inner( ...@@ -223,7 +223,7 @@ def _attn_fwd_inner(
{ {
"BLOCK_M": 128, "BLOCK_M": 128,
"BLOCK_N": 128, "BLOCK_N": 128,
"waves_per_eu": 2, "waves_per_eu": 0,
"PRE_LOAD_V": False, "PRE_LOAD_V": False,
}, },
num_stages=1, num_stages=1,
...@@ -233,7 +233,7 @@ def _attn_fwd_inner( ...@@ -233,7 +233,7 @@ def _attn_fwd_inner(
{ {
"BLOCK_M": 256, "BLOCK_M": 256,
"BLOCK_N": 128, "BLOCK_N": 128,
"waves_per_eu": 2, "waves_per_eu": 0,
"PRE_LOAD_V": False, "PRE_LOAD_V": False,
}, },
num_stages=1, num_stages=1,
...@@ -243,7 +243,7 @@ def _attn_fwd_inner( ...@@ -243,7 +243,7 @@ def _attn_fwd_inner(
{ {
"BLOCK_M": 128, "BLOCK_M": 128,
"BLOCK_N": 64, "BLOCK_N": 64,
"waves_per_eu": 1, "waves_per_eu": 0,
"PRE_LOAD_V": False, "PRE_LOAD_V": False,
}, },
num_stages=1, num_stages=1,
...@@ -253,7 +253,7 @@ def _attn_fwd_inner( ...@@ -253,7 +253,7 @@ def _attn_fwd_inner(
{ {
"BLOCK_M": 128, "BLOCK_M": 128,
"BLOCK_N": 64, "BLOCK_N": 64,
"waves_per_eu": 3, "waves_per_eu": 0,
"PRE_LOAD_V": True, "PRE_LOAD_V": True,
}, },
num_stages=1, num_stages=1,
...@@ -263,7 +263,7 @@ def _attn_fwd_inner( ...@@ -263,7 +263,7 @@ def _attn_fwd_inner(
{ {
"BLOCK_M": 128, "BLOCK_M": 128,
"BLOCK_N": 64, "BLOCK_N": 64,
"waves_per_eu": 3, "waves_per_eu": 0,
"PRE_LOAD_V": False, "PRE_LOAD_V": False,
}, },
num_stages=1, num_stages=1,
...@@ -273,7 +273,7 @@ def _attn_fwd_inner( ...@@ -273,7 +273,7 @@ def _attn_fwd_inner(
{ {
"BLOCK_M": 64, "BLOCK_M": 64,
"BLOCK_N": 64, "BLOCK_N": 64,
"waves_per_eu": 4, "waves_per_eu": 0,
"PRE_LOAD_V": False, "PRE_LOAD_V": False,
}, },
num_stages=1, num_stages=1,
...@@ -283,7 +283,7 @@ def _attn_fwd_inner( ...@@ -283,7 +283,7 @@ def _attn_fwd_inner(
{ {
"BLOCK_M": 32, "BLOCK_M": 32,
"BLOCK_N": 32, "BLOCK_N": 32,
"waves_per_eu": 4, "waves_per_eu": 0,
"PRE_LOAD_V": False, "PRE_LOAD_V": False,
}, },
num_stages=1, num_stages=1,
...@@ -296,7 +296,7 @@ def _attn_fwd_inner( ...@@ -296,7 +296,7 @@ def _attn_fwd_inner(
{ {
"BLOCK_M": 16, "BLOCK_M": 16,
"BLOCK_N": 16, "BLOCK_N": 16,
"waves_per_eu": 1, "waves_per_eu": 0,
"PRE_LOAD_V": False, "PRE_LOAD_V": False,
}, },
num_stages=1, num_stages=1,
......
...@@ -130,7 +130,7 @@ environment_variables: Dict[str, Callable[[], Any]] = { ...@@ -130,7 +130,7 @@ environment_variables: Dict[str, Callable[[], Any]] = {
# flag to control if vllm should use triton flash attention # flag to control if vllm should use triton flash attention
"VLLM_USE_TRITON_FLASH_ATTN": "VLLM_USE_TRITON_FLASH_ATTN":
lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "False").lower() in
("true", "1")), ("true", "1")),
# local rank of the process in the distributed setting, used to determine # local rank of the process in the distributed setting, used to determine
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment