Unverified commit e69a92a1, authored by Wentao Ye and committed by GitHub
Browse files

[Bug] DeepGemm: Fix Cuda Init Error (#21312)


Signed-off-by: yewentao256 <zhyanwentao@126.com>
parent 8425f785
...@@ -45,30 +45,36 @@ def _resolve_symbol(module, new: str, old: str) -> Callable[..., Any] | None: ...@@ -45,30 +45,36 @@ def _resolve_symbol(module, new: str, old: str) -> Callable[..., Any] | None:
return None return None
if not has_deep_gemm(): _fp8_gemm_nt_impl: Callable[..., Any] | None = None
_fp8_gemm_nt_impl: Callable[..., Any] | None = None _grouped_impl: Callable[..., Any] | None = None
_grouped_impl: Callable[..., Any] | None = None _grouped_masked_impl: Callable[..., Any] | None = None
_grouped_masked_impl: Callable[..., Any] | None = None _per_block_cast_impl: Callable[..., Any] | None = None
_per_block_cast_impl: Callable[..., Any] | None = None
else:
_dg = importlib.import_module("deep_gemm") # type: ignore def _lazy_init() -> None:
"""Import deep_gemm and resolve symbols on first use."""
_fp8_gemm_nt_impl = _resolve_symbol( global _fp8_gemm_nt_impl, _grouped_impl, _grouped_masked_impl, \
_dg, _per_block_cast_impl
"fp8_gemm_nt",
"gemm_fp8_fp8_bf16_nt", # fast path
) if (_fp8_gemm_nt_impl is not None or _grouped_impl is not None
or _grouped_masked_impl is not None
or _per_block_cast_impl is not None):
return
if not has_deep_gemm():
return
_dg = importlib.import_module("deep_gemm")
_fp8_gemm_nt_impl = _resolve_symbol(_dg, "fp8_gemm_nt",
"gemm_fp8_fp8_bf16_nt")
_grouped_impl = _resolve_symbol( _grouped_impl = _resolve_symbol(
_dg, _dg, "m_grouped_fp8_gemm_nt_contiguous",
"m_grouped_fp8_gemm_nt_contiguous", "m_grouped_gemm_fp8_fp8_bf16_nt_contiguous")
"m_grouped_gemm_fp8_fp8_bf16_nt_contiguous",
)
_grouped_masked_impl = _resolve_symbol( _grouped_masked_impl = _resolve_symbol(
_dg, _dg, "fp8_m_grouped_gemm_nt_masked",
"fp8_m_grouped_gemm_nt_masked", "m_grouped_gemm_fp8_fp8_bf16_nt_masked")
"m_grouped_gemm_fp8_fp8_bf16_nt_masked",
)
# Try to get per_token_cast_to_fp8 from DeepGEMM math utils. # Try to get per_token_cast_to_fp8 from DeepGEMM math utils.
try: try:
_math_mod = importlib.import_module( _math_mod = importlib.import_module(
...@@ -80,24 +86,28 @@ else: ...@@ -80,24 +86,28 @@ else:
def fp8_gemm_nt(*args, **kwargs): def fp8_gemm_nt(*args, **kwargs):
_lazy_init()
if _fp8_gemm_nt_impl is None: if _fp8_gemm_nt_impl is None:
return _missing(*args, **kwargs) return _missing(*args, **kwargs)
return _fp8_gemm_nt_impl(*args, **kwargs) return _fp8_gemm_nt_impl(*args, **kwargs)
def m_grouped_fp8_gemm_nt_contiguous(*args, **kwargs): def m_grouped_fp8_gemm_nt_contiguous(*args, **kwargs):
_lazy_init()
if _grouped_impl is None: if _grouped_impl is None:
return _missing(*args, **kwargs) return _missing(*args, **kwargs)
return _grouped_impl(*args, **kwargs) return _grouped_impl(*args, **kwargs)
def fp8_m_grouped_gemm_nt_masked(*args, **kwargs): def fp8_m_grouped_gemm_nt_masked(*args, **kwargs):
_lazy_init()
if _grouped_masked_impl is None: if _grouped_masked_impl is None:
return _missing(*args, **kwargs) return _missing(*args, **kwargs)
return _grouped_masked_impl(*args, **kwargs) return _grouped_masked_impl(*args, **kwargs)
def per_block_cast_to_fp8(x, *args, **kwargs): def per_block_cast_to_fp8(x, *args, **kwargs):
_lazy_init()
if _per_block_cast_impl is not None and is_blackwell_deep_gemm_used(): if _per_block_cast_impl is not None and is_blackwell_deep_gemm_used():
return _per_block_cast_impl(x, use_ue8m0=True) return _per_block_cast_impl(x, use_ue8m0=True)
# TODO: refactor the `per_block_cast_to_fp8` from tests to vllm utils # TODO: refactor the `per_block_cast_to_fp8` from tests to vllm utils
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment