Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e69a92a1
Unverified
Commit
e69a92a1
authored
Jul 22, 2025
by
Wentao Ye
Committed by
GitHub
Jul 21, 2025
Browse files
[Bug] DeepGemm: Fix Cuda Init Error (#21312)
Signed-off-by:
yewentao256
<
zhyanwentao@126.com
>
parent
8425f785
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
32 additions
and
22 deletions
+32
-22
vllm/utils/deep_gemm.py
vllm/utils/deep_gemm.py
+32
-22
No files found.
vllm/utils/deep_gemm.py
View file @
e69a92a1
...
...
@@ -45,30 +45,36 @@ def _resolve_symbol(module, new: str, old: str) -> Callable[..., Any] | None:
return
None
if
not
has_deep_gemm
():
_fp8_gemm_nt_impl
:
Callable
[...,
Any
]
|
None
=
None
_grouped_impl
:
Callable
[...,
Any
]
|
None
=
None
_grouped_masked_impl
:
Callable
[...,
Any
]
|
None
=
None
_per_block_cast_impl
:
Callable
[...,
Any
]
|
None
=
None
else
:
_dg
=
importlib
.
import_module
(
"deep_gemm"
)
# type: ignore
_fp8_gemm_nt_impl
=
_resolve_symbol
(
_dg
,
"fp8_gemm_nt"
,
"gemm_fp8_fp8_bf16_nt"
,
)
_fp8_gemm_nt_impl
:
Callable
[...,
Any
]
|
None
=
None
_grouped_impl
:
Callable
[...,
Any
]
|
None
=
None
_grouped_masked_impl
:
Callable
[...,
Any
]
|
None
=
None
_per_block_cast_impl
:
Callable
[...,
Any
]
|
None
=
None
def
_lazy_init
()
->
None
:
"""Import deep_gemm and resolve symbols on first use."""
global
_fp8_gemm_nt_impl
,
_grouped_impl
,
_grouped_masked_impl
,
\
_per_block_cast_impl
# fast path
if
(
_fp8_gemm_nt_impl
is
not
None
or
_grouped_impl
is
not
None
or
_grouped_masked_impl
is
not
None
or
_per_block_cast_impl
is
not
None
):
return
if
not
has_deep_gemm
():
return
_dg
=
importlib
.
import_module
(
"deep_gemm"
)
_fp8_gemm_nt_impl
=
_resolve_symbol
(
_dg
,
"fp8_gemm_nt"
,
"gemm_fp8_fp8_bf16_nt"
)
_grouped_impl
=
_resolve_symbol
(
_dg
,
"m_grouped_fp8_gemm_nt_contiguous"
,
"m_grouped_gemm_fp8_fp8_bf16_nt_contiguous"
,
)
_dg
,
"m_grouped_fp8_gemm_nt_contiguous"
,
"m_grouped_gemm_fp8_fp8_bf16_nt_contiguous"
)
_grouped_masked_impl
=
_resolve_symbol
(
_dg
,
"fp8_m_grouped_gemm_nt_masked"
,
"m_grouped_gemm_fp8_fp8_bf16_nt_masked"
,
)
_dg
,
"fp8_m_grouped_gemm_nt_masked"
,
"m_grouped_gemm_fp8_fp8_bf16_nt_masked"
)
# Try to get per_token_cast_to_fp8 from DeepGEMM math utils.
try
:
_math_mod
=
importlib
.
import_module
(
...
...
@@ -80,24 +86,28 @@ else:
def fp8_gemm_nt(*args, **kwargs):
    """Dispatch to the lazily resolved ``fp8_gemm_nt`` implementation.

    ``_lazy_init()`` is invoked on every call so the module-level
    ``_fp8_gemm_nt_impl`` gets a chance to be populated before it is read.
    If it is still ``None`` afterwards, the call is handed to ``_missing``
    with the same arguments; otherwise the resolved callable is invoked.

    Args:
        *args: Positional arguments forwarded unchanged.
        **kwargs: Keyword arguments forwarded unchanged.

    Returns:
        Whatever the selected callable returns.
    """
    _lazy_init()
    impl = _fp8_gemm_nt_impl
    # Conditional expression keeps `_missing` unevaluated when an impl exists.
    handler = _missing if impl is None else impl
    return handler(*args, **kwargs)
def m_grouped_fp8_gemm_nt_contiguous(*args, **kwargs):
    """Dispatch to the lazily resolved contiguous grouped-GEMM implementation.

    ``_lazy_init()`` is invoked on every call so the module-level
    ``_grouped_impl`` gets a chance to be populated before it is read.
    If it is still ``None`` afterwards, the call is handed to ``_missing``
    with the same arguments; otherwise the resolved callable is invoked.

    Args:
        *args: Positional arguments forwarded unchanged.
        **kwargs: Keyword arguments forwarded unchanged.

    Returns:
        Whatever the selected callable returns.
    """
    _lazy_init()
    impl = _grouped_impl
    # Conditional expression keeps `_missing` unevaluated when an impl exists.
    handler = _missing if impl is None else impl
    return handler(*args, **kwargs)
def fp8_m_grouped_gemm_nt_masked(*args, **kwargs):
    """Dispatch to the lazily resolved masked grouped-GEMM implementation.

    ``_lazy_init()`` is invoked on every call so the module-level
    ``_grouped_masked_impl`` gets a chance to be populated before it is read.
    If it is still ``None`` afterwards, the call is handed to ``_missing``
    with the same arguments; otherwise the resolved callable is invoked.

    Args:
        *args: Positional arguments forwarded unchanged.
        **kwargs: Keyword arguments forwarded unchanged.

    Returns:
        Whatever the selected callable returns.
    """
    _lazy_init()
    impl = _grouped_masked_impl
    # Conditional expression keeps `_missing` unevaluated when an impl exists.
    handler = _missing if impl is None else impl
    return handler(*args, **kwargs)
def
per_block_cast_to_fp8
(
x
,
*
args
,
**
kwargs
):
_lazy_init
()
if
_per_block_cast_impl
is
not
None
and
is_blackwell_deep_gemm_used
():
return
_per_block_cast_impl
(
x
,
use_ue8m0
=
True
)
# TODO: refactor the `per_block_cast_to_fp8` from tests to vllm utils
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment