Commit c284a667 authored by Stefano Castagnetta, committed by khluu
Browse files

[Bugfix] Restrict TRTLLM attention to SM100, fixing GB300 (SM103) hang (#38730)


Signed-off-by: Stefano Castagnetta <scastagnetta@nvidia.com>
(cherry picked from commit 6183cae1)
parent 3a30a1a6
...@@ -167,7 +167,7 @@ Priority is **1 = highest** (tried first). ...@@ -167,7 +167,7 @@ Priority is **1 = highest** (tried first).
| ------- | ------- | ------ | --------- | ----------- | ---------- | ---- | --------- | --- | --------------- | ------------ | | ------- | ------- | ------ | --------- | ----------- | ---------- | ---- | --------- | --- | --------------- | ------------ |
| `CPU_ATTN` | | fp16, bf16, fp32 | `auto` | Any | 32, 64, 80, 96, 112, 128, 160, 192, 224, 256 | ❌ | ❌ | ❌ | All | N/A | | `CPU_ATTN` | | fp16, bf16, fp32 | `auto` | Any | 32, 64, 80, 96, 112, 128, 160, 192, 224, 256 | ❌ | ❌ | ❌ | All | N/A |
| `FLASHINFER` | Native† | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32, 64 | 64, 128, 256 | ❌ | ❌ | ✅ | Decoder | 7.x-9.x | | `FLASHINFER` | Native† | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32, 64 | 64, 128, 256 | ❌ | ❌ | ✅ | Decoder | 7.x-9.x |
| `FLASHINFER` | TRTLLM† | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32, 64 | 64, 128, 256 | ✅ | ❌ | ✅ | Decoder | 10.x | | `FLASHINFER` | TRTLLM† | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32, 64 | 64, 128, 256 | ✅ | ❌ | ✅ | Decoder | 10.0 |
| `FLASH_ATTN` | FA2* | fp16, bf16 | `auto`, `float16`, `bfloat16` | %16 | Any | ❌ | ❌ | ✅ | All | ≥8.0 | | `FLASH_ATTN` | FA2* | fp16, bf16 | `auto`, `float16`, `bfloat16` | %16 | Any | ❌ | ❌ | ✅ | All | ≥8.0 |
| `FLASH_ATTN` | FA3* | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | Any | ✅ | ❌ | ✅ | All | 9.x | | `FLASH_ATTN` | FA3* | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | Any | ✅ | ❌ | ✅ | All | 9.x |
| `FLASH_ATTN` | FA4* | fp16, bf16 | `auto`, `float16`, `bfloat16` | %16 | Any | ❌ | ❌ | ✅ | All | ≥10.0 | | `FLASH_ATTN` | FA4* | fp16, bf16 | `auto`, `float16`, `bfloat16` | %16 | Any | ❌ | ❌ | ✅ | All | ≥10.0 |
......
...@@ -235,10 +235,11 @@ def _resolve_import_to_file( ...@@ -235,10 +235,11 @@ def _resolve_import_to_file(
def _find_cc_in_function(tree: ast.AST, func_name: str) -> str | None: def _find_cc_in_function(tree: ast.AST, func_name: str) -> str | None:
"""Find a compute capability from is_device_capability_family() calls in a function. """Find a compute capability from is_device_capability*() calls in a function.
Looks for the pattern: current_platform.is_device_capability_family(N) Handles two patterns:
and converts N (e.g. 100) to a CC string (e.g. "10.x"). - is_device_capability_family(N): "M.x" (e.g. 100 -> "10.x")
- is_device_capability(N): "M.m" (e.g. 100 -> "10.0")
""" """
for node in ast.walk(tree): for node in ast.walk(tree):
if not isinstance(node, ast.FunctionDef) or node.name != func_name: if not isinstance(node, ast.FunctionDef) or node.name != func_name:
...@@ -247,12 +248,15 @@ def _find_cc_in_function(tree: ast.AST, func_name: str) -> str | None: ...@@ -247,12 +248,15 @@ def _find_cc_in_function(tree: ast.AST, func_name: str) -> str | None:
if ( if (
isinstance(n, ast.Call) isinstance(n, ast.Call)
and isinstance(n.func, ast.Attribute) and isinstance(n.func, ast.Attribute)
and n.func.attr == "is_device_capability_family"
and n.args and n.args
and isinstance(n.args[0], ast.Constant) and isinstance(n.args[0], ast.Constant)
and isinstance(n.args[0].value, int) and isinstance(n.args[0].value, int)
): ):
return f"{n.args[0].value // 10}.x" val = n.args[0].value
if n.func.attr == "is_device_capability_family":
return f"{val // 10}.x"
elif n.func.attr == "is_device_capability":
return f"{val // 10}.{val % 10}"
return None return None
......
...@@ -289,10 +289,10 @@ def supports_trtllm_attention() -> bool: ...@@ -289,10 +289,10 @@ def supports_trtllm_attention() -> bool:
if envs.VLLM_BATCH_INVARIANT: if envs.VLLM_BATCH_INVARIANT:
return False return False
# Requires SM100 and NVIDIA artifactory to be accessible to download cubins # TRTLLM attention is currently only validated on SM100 (CC 10.0).
return ( # SM103 (GB300) hangs with FlashInfer >= 0.6.7.
current_platform.is_device_capability_family(100) and has_nvidia_artifactory() # See: https://github.com/flashinfer-ai/flashinfer/issues/2939
) return current_platform.is_device_capability(100) and has_nvidia_artifactory()
def force_use_trtllm_attention() -> bool | None: def force_use_trtllm_attention() -> bool | None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment