Unverified commit 6669d127, authored by Yineng Zhang, committed by GitHub
Browse files

feat: add DeepGEMM build warning (#5176)


Co-authored-by: grimoire <streetyao@live.com>
parent f2b70afd
......@@ -16,6 +16,7 @@ import functools
import json
import logging
import os
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Tuple
import torch
......@@ -59,7 +60,10 @@ if supports_custom_op():
Bs: torch.Tensor,
C: torch.Tensor,
) -> None:
deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
M, K = A.shape
N, _ = B.shape
with _log_jit_build(M, N, K):
deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
def deep_gemm_fp8_fp8_bf16_nt_fake(
A: torch.Tensor,
......@@ -708,6 +712,25 @@ def get_w8a8_block_fp8_configs(
return None
@contextmanager
def _log_jit_build(M: int, N: int, K: int):
    """Warn when a DeepGEMM runtime-cache lookup misses inside the ``with`` body.

    While the context is active, ``RuntimeCache.__getitem__`` is replaced by a
    wrapper that delegates to the original lookup and logs a warning whenever
    that lookup returns ``None``. Presumably a ``None`` result means no
    compiled kernel is cached yet and DeepGEMM is about to JIT-build one --
    TODO confirm against the DeepGEMM runtime.

    The patch is applied to the ``RuntimeCache`` class itself, so it affects
    every instance for as long as the context is active.

    Args:
        M: First dimension of ``A``; only used in the log message.
        N: First dimension of ``B``; only used in the log message.
        K: Second dimension of ``A``; only used in the log message.
    """
    # Imported lazily so that this module does not require deep_gemm at
    # import time.
    from deep_gemm.jit.runtime import RuntimeCache

    origin_func = RuntimeCache.__getitem__

    def _patched_getitem(self, *args, **kwargs):
        ret = origin_func(self, *args, **kwargs)
        if ret is None:
            logger.warning(
                f"DeepGEMM JIT code generation <gemm_fp8_fp8_bf16_nt>: M={M}, N={N}, K={K}. Please wait."
            )
        return ret

    RuntimeCache.__getitem__ = _patched_getitem
    try:
        yield
    finally:
        # Always undo the monkey patch, even if the wrapped call raises;
        # otherwise the wrapper (and its stale M/N/K) would stay installed
        # on the class for the rest of the process.
        RuntimeCache.__getitem__ = origin_func
def w8a8_block_fp8_matmul(
A: torch.Tensor,
B: torch.Tensor,
......@@ -782,7 +805,8 @@ def w8a8_block_fp8_matmul(
if supports_custom_op():
torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C)
else:
deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
with _log_jit_build(M, N, K):
deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
else:
kernel = (
_w8a8_block_fp8_matmul_unrolledx4
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment