Unverified commit 6669d127, authored by Yineng Zhang and committed by GitHub
Browse files

feat: add DeepGEMM build warning (#5176)


Co-authored-by: grimoire <streetyao@live.com>
parent f2b70afd
...@@ -16,6 +16,7 @@ import functools ...@@ -16,6 +16,7 @@ import functools
import json import json
import logging import logging
import os import os
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
import torch import torch
...@@ -59,6 +60,9 @@ if supports_custom_op(): ...@@ -59,6 +60,9 @@ if supports_custom_op():
Bs: torch.Tensor, Bs: torch.Tensor,
C: torch.Tensor, C: torch.Tensor,
) -> None: ) -> None:
M, K = A.shape
N, _ = B.shape
with _log_jit_build(M, N, K):
deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C) deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
def deep_gemm_fp8_fp8_bf16_nt_fake( def deep_gemm_fp8_fp8_bf16_nt_fake(
...@@ -708,6 +712,25 @@ def get_w8a8_block_fp8_configs( ...@@ -708,6 +712,25 @@ def get_w8a8_block_fp8_configs(
return None return None
@contextmanager
def _log_jit_build(M: int, N: int, K: int):
from deep_gemm.jit.runtime import RuntimeCache
origin_func = RuntimeCache.__getitem__
def __patched_func(self, *args, **kwargs):
ret = origin_func(self, *args, **kwargs)
if ret is None:
logger.warning(
f"DeepGEMM JIT code generation <gemm_fp8_fp8_bf16_nt>: M={M}, N={N}, K={K}. Please wait."
)
return ret
RuntimeCache.__getitem__ = __patched_func
yield
RuntimeCache.__getitem__ = origin_func
def w8a8_block_fp8_matmul( def w8a8_block_fp8_matmul(
A: torch.Tensor, A: torch.Tensor,
B: torch.Tensor, B: torch.Tensor,
...@@ -782,6 +805,7 @@ def w8a8_block_fp8_matmul( ...@@ -782,6 +805,7 @@ def w8a8_block_fp8_matmul(
if supports_custom_op(): if supports_custom_op():
torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C) torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C)
else: else:
with _log_jit_build(M, N, K):
deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C) deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
else: else:
kernel = ( kernel = (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment