Unverified Commit 633f6e80 authored by Wentao Ye, committed by GitHub
Browse files

[Bug] Fix DeepGemm Init Error (#21554)


Signed-off-by: yewentao256 <zhyanwentao@126.com>
parent b57296bb
...@@ -366,7 +366,7 @@ def per_token_group_quant_fp8( ...@@ -366,7 +366,7 @@ def per_token_group_quant_fp8(
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
column_major_scales: bool = False, column_major_scales: bool = False,
out_q: Optional[torch.Tensor] = None, out_q: Optional[torch.Tensor] = None,
use_ue8m0: bool = is_blackwell_deep_gemm_used(), use_ue8m0: Optional[bool] = None,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
"""Function to perform per-token-group quantization on an input tensor `x`. """Function to perform per-token-group quantization on an input tensor `x`.
It converts the tensor values into signed float8 values and returns the It converts the tensor values into signed float8 values and returns the
...@@ -383,6 +383,10 @@ def per_token_group_quant_fp8( ...@@ -383,6 +383,10 @@ def per_token_group_quant_fp8(
tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
scaling factor. scaling factor.
""" """
# TODO(wentao): refactor this
# use_ue8m0 should be a global flag that could be set by user
if use_ue8m0 is None:
use_ue8m0 = is_blackwell_deep_gemm_used()
dtype = current_platform.fp8_dtype() if dtype is None else dtype dtype = current_platform.fp8_dtype() if dtype is None else dtype
assert (x.shape[-1] % group_size == 0), ( assert (x.shape[-1] % group_size == 0), (
f"the last dimension of `x` {x.shape[-1]} must be divisible " f"the last dimension of `x` {x.shape[-1]} must be divisible "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment