Unverified Commit 2a2d3478 authored by fzyzcjy's avatar fzyzcjy Committed by GitHub
Browse files

Fix wrong GEMM branch selection that caused a 250 µs slowdown (#7969)

parent aa205609
......@@ -2193,7 +2193,6 @@ class DeepseekV2ForCausalLM(nn.Module):
# This may affect the accuracy of fp8 model.
# Fix deepseek v3 blockwise bmm by using deep_gemm
use_deep_gemm_bmm = False
- model_dtype = torch.get_default_dtype()
if w.dtype in (
torch.float8_e4m3fn,
......@@ -2219,7 +2218,6 @@ class DeepseekV2ForCausalLM(nn.Module):
_is_cuda
and weight_block_size[0] == 128
and weight_block_size[1] == 128
- and model_dtype == torch.bfloat16
):
if (
deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
......@@ -2233,7 +2231,7 @@ class DeepseekV2ForCausalLM(nn.Module):
weight,
weight_scale,
weight_block_size,
- model_dtype,
+ torch.bfloat16,
)
else:
w, scale = block_quant_to_tensor_quant(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment