"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "a1fad8286f86c46821f8038d86e358e9cc62d20f"
Unverified Commit aa1c5cf5 authored by Baizhou Zhang's avatar Baizhou Zhang Committed by GitHub
Browse files

Add warnings and remove dependency for deterministic inference (#10724)


Co-authored-by: default avatarYineng Zhang <me@zhyncs.com>
parent 592caab6
...@@ -981,29 +981,36 @@ class ServerArgs: ...@@ -981,29 +981,36 @@ class ServerArgs:
def _handle_deterministic_inference(self):
    """Validate and adjust server settings required for deterministic inference.

    When ``enable_deterministic_inference`` is set, this forces the pytorch
    sampling backend, checks that the attention backend is one of the
    supported deterministic backends, disables the radix cache for backends
    other than fa3, and rejects tensor-parallel sizes above 1. A warning is
    emitted because only dense models have been tested.

    Raises:
        ValueError: if the attention backend is unsupported, or if tp_size > 1.
    """
    if not self.enable_deterministic_inference:
        return

    # Deterministic inference requires the pytorch sampling backend.
    self.sampling_backend = "pytorch"
    logger.warning(
        "Sampling backend is set to pytorch for deterministic inference."
    )

    # Only a restricted set of attention backends supports determinism.
    if self.attention_backend not in DETERMINISTIC_ATTENTION_BACKEND_CHOICES:
        raise ValueError(
            f"Currently only {DETERMINISTIC_ATTENTION_BACKEND_CHOICES} attention backends are supported for deterministic inference."
        )

    # Currently, only FA3 supports radix cache. Support for other backends is in progress
    if self.attention_backend != "fa3":
        self.disable_radix_cache = True
        logger.warning(
            f"Currently radix cache is not compatible with {self.attention_backend} attention backend for deterministic inference. It will be supported in the future."
        )

    # Tensor parallelism beyond a single rank is not yet deterministic.
    if self.tp_size > 1:
        raise ValueError(
            "Currently only TP size 1 is supported for deterministic inference."
        )

    # Determinism has only been validated on dense models so far.
    logger.warning(
        "Currently deterministic inference is only tested on dense models. Please be cautious when using it on MoE models."
    )
def _handle_other_validations(self): def _handle_other_validations(self):
pass pass
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment