Unverified Commit f4b78d13 authored by Minglei Zhu, committed by GitHub

[1/2] deepseek deterministic: support deterministic inference for deepseek arch models on a single GPU (#12000)
parent 4463e90d
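
For context before the diff, a minimal launch sketch (not part of this commit): it assumes the CLI flags --enable-deterministic-inference and --attention-backend mirror the enable_deterministic_inference and attention_backend server args touched below, and the model path is only a placeholder.

# Hypothetical single-GPU launch; flag spellings are assumptions derived from the
# ServerArgs fields in this diff, so verify them against your sglang version.
import subprocess

subprocess.run(
    [
        "python", "-m", "sglang.launch_server",
        "--model-path", "deepseek-ai/DeepSeek-V2-Lite",  # placeholder DeepSeek-arch checkpoint
        "--enable-deterministic-inference",  # enables the code paths added below
        "--attention-backend", "triton",  # fa3 or triton, per the checks added below
    ],
    check=True,
)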
@@ -495,8 +495,29 @@ def mean_batch_invariant(input, dim, keepdim=False, dtype: torch.dtype | None =
     return torch.sum(input, dim=dim, keepdim=keepdim, dtype=torch.float32) / n_elems
 
 
+def bmm_batch_invariant(a, b, *, out=None):
+    # Batched matrix multiply: (B, M, K) x (B, K, N) -> (B, M, N)
+    # Process each batch separately with our persistent kernel
+    if a.ndim == 3 and b.ndim == 3:
+        results = []
+        for i in range(a.shape[0]):
+            results.append(matmul_persistent(a[i], b[i]))
+        result = torch.stack(results, dim=0)
+        if out is not None:
+            out.copy_(result)
+            return out
+        return result
+    else:
+        raise ValueError(
+            f"bmm_batch_invariant expects 3D tensors, "
+            f"got shapes {a.shape} and {b.shape}"
+        )
+
+
 _batch_invariant_MODE = False
 _batch_invariant_LIB = None
+_original_torch_bmm = None
 
 
 def is_batch_invariant_mode_enabled():
@@ -504,7 +525,7 @@ def is_batch_invariant_mode_enabled():
 def enable_batch_invariant_mode():
-    global _batch_invariant_MODE, _batch_invariant_LIB
+    global _batch_invariant_MODE, _batch_invariant_LIB, _original_torch_bmm
     if _batch_invariant_MODE:
         return
@@ -516,12 +537,20 @@ def enable_batch_invariant_mode():
         "aten::_log_softmax", _log_softmax_batch_invariant, "CUDA"
     )
     _batch_invariant_LIB.impl("aten::mean.dim", mean_batch_invariant, "CUDA")
+    _batch_invariant_LIB.impl("aten::bmm", bmm_batch_invariant, "CUDA")
+
+    # Also monkeypatch torch.bmm directly as a fallback
+    _original_torch_bmm = torch.bmm
+    torch.bmm = bmm_batch_invariant
 
 
 def disable_batch_invariant_mode():
-    global _batch_invariant_MODE, _batch_invariant_LIB
+    global _batch_invariant_MODE, _batch_invariant_LIB, _original_torch_bmm
     if _batch_invariant_LIB is not None:
         _batch_invariant_LIB._destroy()
+
+    if _original_torch_bmm is not None:
+        torch.bmm = _original_torch_bmm
+        _original_torch_bmm = None
+
     _batch_invariant_MODE = False
     _batch_invariant_LIB = None
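
The torch.bmm patch added above follows a save-and-restore pattern around a per-batch matmul. Below is a self-contained sketch of the same idea using a plain Python loop in place of sglang's matmul_persistent kernel; the names (_loop_bmm, enable_patch, disable_patch) are illustrative, not library API.

# Standalone sketch of the save/patch/restore pattern used in the diff above.
import torch

_saved_bmm = None

def _loop_bmm(a, b, *, out=None):
    # Multiply each (M, K) x (K, N) pair separately so every batch element
    # goes through the same reduction path regardless of batch size.
    result = torch.stack([a[i] @ b[i] for i in range(a.shape[0])], dim=0)
    if out is not None:
        out.copy_(result)
        return out
    return result

def enable_patch():
    global _saved_bmm
    if _saved_bmm is None:
        _saved_bmm = torch.bmm
        torch.bmm = _loop_bmm

def disable_patch():
    global _saved_bmm
    if _saved_bmm is not None:
        torch.bmm = _saved_bmm
        _saved_bmm = None

if __name__ == "__main__":
    a, b = torch.randn(4, 8, 16), torch.randn(4, 16, 8)
    enable_patch()
    patched = torch.bmm(a, b)  # routed through _loop_bmm
    disable_patch()
    assert torch.allclose(patched, torch.bmm(a, b), atol=1e-5)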
@@ -350,7 +350,11 @@ def handle_attention_flashinfer(attn, forward_batch):
 def handle_attention_fa3(attn, forward_batch):
-    return _handle_attention_backend(attn, forward_batch, "fa3")
+    # when deterministic inference is enabled, use MLA
+    if get_global_server_args().enable_deterministic_inference:
+        return _dispatch_mla_subtype(attn, forward_batch)
+    else:
+        return _handle_attention_backend(attn, forward_batch, "fa3")
 
 
 def handle_attention_flashmla(attn, forward_batch):
@@ -394,6 +398,10 @@ def handle_attention_nsa(attn, forward_batch):
 def handle_attention_triton(attn, forward_batch):
+    # when deterministic inference is enabled, use MLA
+    if get_global_server_args().enable_deterministic_inference:
+        return _dispatch_mla_subtype(attn, forward_batch)
+
     if (
         _is_extend_without_speculative(forward_batch)
         and sum(forward_batch.extend_prefix_lens_cpu) == 0
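
Both handler changes above share one control-flow pattern: consult the global server args once and, when deterministic inference is enabled, reroute to the MLA dispatcher instead of the backend-specific path. A toy restatement of that routing with stub handlers (DummyArgs, dispatch_mla_subtype, and handle_backend are placeholders, not sglang internals):

# Toy routing sketch mirroring handle_attention_fa3 / handle_attention_triton above.
from dataclasses import dataclass

@dataclass
class DummyArgs:
    enable_deterministic_inference: bool = False

_args = DummyArgs()

def dispatch_mla_subtype(batch):
    return f"mla({batch})"

def handle_backend(batch, backend):
    return f"{backend}({batch})"

def handle_fa3(batch):
    # when deterministic inference is enabled, use MLA
    if _args.enable_deterministic_inference:
        return dispatch_mla_subtype(batch)
    return handle_backend(batch, "fa3")

assert handle_fa3("b0") == "fa3(b0)"
_args.enable_deterministic_inference = True
assert handle_fa3("b0") == "mla(b0)"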
@@ -1532,13 +1532,30 @@ class ServerArgs:
             logger.warning(
                 "Sampling backend is set to pytorch for deterministic inference."
             )
 
+        is_deepseek_model = False
+        if parse_connector_type(self.model_path) != ConnectorType.INSTANCE:
+            try:
+                hf_config = self.get_hf_config()
+                model_arch = hf_config.architectures[0]
+                is_deepseek_model = model_arch in [
+                    "DeepseekV2ForCausalLM",
+                    "DeepseekV3ForCausalLM",
+                    "DeepseekV32ForCausalLM",
+                ]
+            except Exception:
+                pass
+
         # Check attention backend
         if self.attention_backend is None:
             # User didn't specify attention backend, fallback based on GPU architecture
             if is_sm100_supported() or is_sm120_supported():
                 # Blackwell and newer architectures
-                self.attention_backend = "flashinfer"
+                if is_deepseek_model:
+                    # fallback to triton for DeepSeek models because flashinfer doesn't support deterministic inference for DeepSeek models yet
+                    self.attention_backend = "triton"
+                else:
+                    # fallback to flashinfer on Blackwell for non-DeepSeek models
+                    self.attention_backend = "flashinfer"
             else:
                 # Hopper (SM90) and older architectures
                 self.attention_backend = "fa3"
@@ -1553,8 +1570,13 @@ class ServerArgs:
                     f"but you explicitly specified '{self.attention_backend}'."
                 )
 
-        # Currently, only FA3 and Triton supports radix cache. Support for other backends is in progress
         if self.attention_backend not in ["fa3", "triton"]:
+            if is_deepseek_model:
+                raise ValueError(
+                    f"Currently only fa3 and triton attention backends are supported for deterministic inference with DeepSeek models. But you're using {self.attention_backend}."
+                )
+
+            # Currently, only FA3 and Triton supports radix cache. Support for other backends is in progress
             self.disable_radix_cache = True
             logger.warning(
                 f"Currently radix cache is not compatible with {self.attention_backend} attention backend for deterministic inference. It will be supported in the future."
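
The ServerArgs changes above boil down to a small decision table: auto-select triton for DeepSeek models on Blackwell (flashinfer lacks deterministic support for them), flashinfer for other Blackwell models, fa3 on Hopper and older, and reject non-fa3/triton backends for DeepSeek models. A pure-Python restatement of that table follows; pick_attention_backend and its boolean arguments are illustrative stand-ins for the architecture checks and model detection in the diff, and the radix-cache handling is omitted.

# Hypothetical helper restating the backend-selection logic above; not sglang code.
def pick_attention_backend(requested, is_blackwell, is_deepseek_model):
    if requested is None:
        if is_blackwell:
            # flashinfer doesn't support deterministic inference for DeepSeek models yet
            return "triton" if is_deepseek_model else "flashinfer"
        return "fa3"  # Hopper (SM90) and older architectures
    if is_deepseek_model and requested not in ("fa3", "triton"):
        raise ValueError(
            f"only fa3 and triton support deterministic inference for DeepSeek models, got {requested}"
        )
    return requested

assert pick_attention_backend(None, True, True) == "triton"
assert pick_attention_backend(None, True, False) == "flashinfer"
assert pick_attention_backend(None, False, True) == "fa3"
assert pick_attention_backend("fa3", True, True) == "fa3"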