Unverified Commit b4a1ed85 authored by Cheng Lu, committed by GitHub

Add multistep DPM-Solver discrete scheduler (#1132)



* add dpmsolver discrete pytorch scheduler

* fix some typos in dpm-solver pytorch

* add dpm-solver pytorch in stable-diffusion pipeline

* add jax/flax version dpm-solver

* change code style

* change code style

* add docs

* add `add_noise` method for dpmsolver

* add pytorch unit test for dpmsolver

* add dummy object for pytorch dpmsolver

* Update src/diffusers/schedulers/scheduling_dpmsolver_discrete.py
Co-authored-by: Suraj Patil <surajp815@gmail.com>

* Update tests/test_config.py
Co-authored-by: Suraj Patil <surajp815@gmail.com>

* Update tests/test_config.py
Co-authored-by: Suraj Patil <surajp815@gmail.com>

* resolve the code comments

* rename the file

* change class name

* fix code style

* add auto docs for dpmsolver multistep

* add more explanations for the stabilizing trick (for steps < 15)

* delete the dummy file

* change the API name of predict_epsilon, algorithm_type and solver_type

* add compatible lists
Co-authored-by: Suraj Patil <surajp815@gmail.com>
parent 08a6dc8a
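For orientation, here is a minimal usage sketch of the new scheduler with the PyTorch Stable Diffusion pipeline. This is not part of the diff; the repo id, dtype, and prompt are illustrative, and the `from_config` call mirrors the new `test_load_dpmsolver` test added below.

```python
# Illustrative sketch only: load the new multistep DPM-Solver scheduler and plug it
# into StableDiffusionPipeline. Repo id and prompt are placeholders.
import torch
from diffusers import DPMSolverMultistepScheduler, StableDiffusionPipeline

# Load the scheduler from the model's scheduler config (same pattern as test_load_dpmsolver).
scheduler = DPMSolverMultistepScheduler.from_config(
    "runwayml/stable-diffusion-v1-5", subfolder="scheduler"
)

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", scheduler=scheduler, torch_dtype=torch.float16
).to("cuda")

# DPM-Solver targets few-step sampling, so ~20-25 inference steps is a typical setting.
image = pipe("a photo of an astronaut riding a horse", num_inference_steps=25).images[0]
```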
...@@ -70,6 +70,12 @@ Original paper can be found [here](https://arxiv.org/abs/2010.02502).
[[autodoc]] DDPMScheduler
#### Multistep DPM-Solver
Original paper can be found [here](https://arxiv.org/abs/2206.00927) and the [improved version](https://arxiv.org/abs/2211.01095). The original implementation can be found [here](https://github.com/LuChengTHU/dpm-solver).
[[autodoc]] DPMSolverMultistepScheduler
#### Variance exploding, stochastic sampling from Karras et. al
Original paper can be found [here](https://arxiv.org/abs/2006.11239).
......
...@@ -42,6 +42,7 @@ if is_torch_available():
from .schedulers import (
DDIMScheduler,
DDPMScheduler,
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
IPNDMScheduler,
...@@ -92,6 +93,7 @@ if is_flax_available():
from .schedulers import (
FlaxDDIMScheduler,
FlaxDDPMScheduler,
FlaxDPMSolverMultistepScheduler,
FlaxKarrasVeScheduler,
FlaxLMSDiscreteScheduler,
FlaxPNDMScheduler,
......
...@@ -14,7 +14,12 @@ from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel
from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel
from ...pipeline_flax_utils import FlaxDiffusionPipeline
from ...schedulers import (
FlaxDDIMScheduler,
FlaxDPMSolverMultistepScheduler,
FlaxLMSDiscreteScheduler,
FlaxPNDMScheduler,
)
from ...utils import logging
from . import FlaxStableDiffusionPipelineOutput
from .safety_checker_flax import FlaxStableDiffusionSafetyChecker
...@@ -43,7 +48,8 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
unet ([`FlaxUNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], [`FlaxPNDMScheduler`], or
[`FlaxDPMSolverMultistepScheduler`].
safety_checker ([`FlaxStableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
...@@ -57,7 +63,9 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
text_encoder: FlaxCLIPTextModel,
tokenizer: CLIPTokenizer,
unet: FlaxUNet2DConditionModel,
scheduler: Union[
FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler
],
safety_checker: FlaxStableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
dtype: jnp.dtype = jnp.float32,
......
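As a rough sketch of how the Flax scheduler is driven (not from this diff; this assumes `FlaxDPMSolverMultistepScheduler` follows the explicit-state `create_state`/`set_timesteps` pattern of the other Flax schedulers, and the latent shape below is a placeholder):

```python
# Hedged sketch: Flax schedulers keep their mutable state in a separate dataclass
# that is threaded through the sampling loop.
from diffusers import FlaxDPMSolverMultistepScheduler

scheduler = FlaxDPMSolverMultistepScheduler()  # default config values assumed
state = scheduler.create_state()               # assumed: initial scheduler state
# assumed signature: set_timesteps returns a new state; the shape is a placeholder latent shape
state = scheduler.set_timesteps(state, num_inference_steps=25, shape=(1, 4, 64, 64))
```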
...@@ -11,6 +11,7 @@ from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import (
DDIMScheduler,
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
LMSDiscreteScheduler,
...@@ -59,7 +60,12 @@ class StableDiffusionPipeline(DiffusionPipeline):
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
],
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
......
...@@ -19,6 +19,7 @@ from ..utils import is_flax_available, is_scipy_available, is_torch_available
if is_torch_available():
from .scheduling_ddim import DDIMScheduler
from .scheduling_ddpm import DDPMScheduler
from .scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
from .scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler
from .scheduling_euler_discrete import EulerDiscreteScheduler
from .scheduling_ipndm import IPNDMScheduler
...@@ -35,6 +36,7 @@ else:
if is_flax_available():
from .scheduling_ddim_flax import FlaxDDIMScheduler
from .scheduling_ddpm_flax import FlaxDDPMScheduler
from .scheduling_dpmsolver_multistep_flax import FlaxDPMSolverMultistepScheduler
from .scheduling_karras_ve_flax import FlaxKarrasVeScheduler
from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler
from .scheduling_pndm_flax import FlaxPNDMScheduler
......
...@@ -115,6 +115,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
"LMSDiscreteScheduler",
"EulerDiscreteScheduler",
"EulerAncestralDiscreteScheduler",
"DPMSolverMultistepScheduler",
]
@register_to_config
......
...@@ -108,6 +108,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
"LMSDiscreteScheduler",
"EulerDiscreteScheduler",
"EulerAncestralDiscreteScheduler",
"DPMSolverMultistepScheduler",
]
@register_to_config
......
This diff is collapsed.
...@@ -73,6 +73,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
"LMSDiscreteScheduler",
"PNDMScheduler",
"EulerDiscreteScheduler",
"DPMSolverMultistepScheduler",
]
@register_to_config
......
...@@ -74,6 +74,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
"LMSDiscreteScheduler",
"PNDMScheduler",
"EulerAncestralDiscreteScheduler",
"DPMSolverMultistepScheduler",
]
@register_to_config
......
...@@ -73,6 +73,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
"PNDMScheduler",
"EulerDiscreteScheduler",
"EulerAncestralDiscreteScheduler",
"DPMSolverMultistepScheduler",
]
@register_to_config
......
...@@ -94,6 +94,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
"LMSDiscreteScheduler",
"EulerDiscreteScheduler",
"EulerAncestralDiscreteScheduler",
"DPMSolverMultistepScheduler",
]
@register_to_config
......
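The `"DPMSolverMultistepScheduler"` entries added to the compatible lists above mean the new scheduler can be instantiated from a config saved by any of those schedulers without a config-mismatch warning, which is what the new `test_load_dpmsolver` test below checks. A rough sketch of the swap, with a placeholder local path:

```python
# Hedged sketch: reuse an existing (e.g. PNDM) scheduler config for the new DPM-Solver.
# The local directory name is a placeholder.
from diffusers import DPMSolverMultistepScheduler, PNDMScheduler

pndm = PNDMScheduler.from_config("runwayml/stable-diffusion-v1-5", subfolder="scheduler")
pndm.save_config("./sd15_scheduler")

# Since PNDMScheduler lists "DPMSolverMultistepScheduler" as compatible, loading the saved
# config with the new class should not log a warning about an unexpected scheduler class.
dpm = DPMSolverMultistepScheduler.from_config("./sd15_scheduler")
```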
...@@ -94,6 +94,21 @@ class FlaxDDPMScheduler(metaclass=DummyObject):
requires_backends(cls, ["flax"])
class FlaxDPMSolverMultistepScheduler(metaclass=DummyObject):
    _backends = ["flax"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["flax"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["flax"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["flax"])
class FlaxKarrasVeScheduler(metaclass=DummyObject):
    _backends = ["flax"]
......
...@@ -302,6 +302,21 @@ class DDPMScheduler(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class DPMSolverMultistepScheduler(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])
class EulerAncestralDiscreteScheduler(metaclass=DummyObject):
    _backends = ["torch"]
......
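For readers unfamiliar with the dummy objects above: they are import-time placeholders that only raise a helpful error once someone actually tries to use a class whose backend is not installed. A simplified sketch of the mechanism follows; the names mirror `diffusers.utils`, but the bodies are illustrative, not the library's literal implementation.

```python
import importlib.util


def requires_backends(obj, backends):
    # Raise an informative ImportError listing whichever backends (e.g. "torch", "flax") are missing.
    name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
    missing = [b for b in backends if importlib.util.find_spec(b) is None]
    if missing:
        raise ImportError(f"{name} requires the missing backend(s): {', '.join(missing)}")


class DummyObject(type):
    # Metaclass: attribute access on a dummy class re-runs the backend check, so e.g.
    # FlaxDPMSolverMultistepScheduler.from_pretrained(...) fails with a clear message.
    def __getattr__(cls, key):
        requires_backends(cls, cls._backends)
```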
...@@ -19,7 +19,14 @@ import tempfile
import unittest
import diffusers
from diffusers import (
DDIMScheduler,
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
PNDMScheduler,
logging,
)
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils.testing_utils import CaptureLogger
...@@ -283,3 +290,15 @@ class ConfigTester(unittest.TestCase):
assert pndm.__class__ == PNDMScheduler
# no warning should be thrown
assert cap_logger.out == ""
    def test_load_dpmsolver(self):
        logger = logging.get_logger("diffusers.configuration_utils")

        with CaptureLogger(logger) as cap_logger:
            dpm = DPMSolverMultistepScheduler.from_config(
                "hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler"
            )

        assert dpm.__class__ == DPMSolverMultistepScheduler
        # no warning should be thrown
        assert cap_logger.out == ""
...@@ -24,6 +24,7 @@ import torch.nn.functional as F
from diffusers import (
DDIMScheduler,
DDPMScheduler,
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
IPNDMScheduler,
...@@ -549,6 +550,187 @@ class DDIMSchedulerTest(SchedulerCommonTest):
assert abs(result_mean.item() - 0.1941) < 1e-3
class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
    scheduler_classes = (DPMSolverMultistepScheduler,)
    forward_default_kwargs = (("num_inference_steps", 25),)

    def get_scheduler_config(self, **kwargs):
        config = {
            "num_train_timesteps": 1000,
            "beta_start": 0.0001,
            "beta_end": 0.02,
            "beta_schedule": "linear",
            "solver_order": 2,
            "predict_epsilon": True,
            "thresholding": False,
            "sample_max_value": 1.0,
            "algorithm_type": "dpmsolver++",
            "solver_type": "midpoint",
            "lower_order_final": False,
        }
        config.update(**kwargs)
        return config

    def check_over_configs(self, time_step=0, **config):
        kwargs = dict(self.forward_default_kwargs)
        num_inference_steps = kwargs.pop("num_inference_steps", None)
        sample = self.dummy_sample
        residual = 0.1 * sample
        dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10]

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config(**config)
            scheduler = scheduler_class(**scheduler_config)
            scheduler.set_timesteps(num_inference_steps)
            # copy over dummy past residuals
            scheduler.model_outputs = dummy_past_residuals[: scheduler.config.solver_order]

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_config(tmpdirname)
                new_scheduler = scheduler_class.from_config(tmpdirname)
                new_scheduler.set_timesteps(num_inference_steps)
                # copy over dummy past residuals
                new_scheduler.model_outputs = dummy_past_residuals[: new_scheduler.config.solver_order]

            output, new_output = sample, sample
            for t in range(time_step, time_step + scheduler.config.solver_order + 1):
                output = scheduler.step(residual, t, output, **kwargs).prev_sample
                new_output = new_scheduler.step(residual, t, new_output, **kwargs).prev_sample

                assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

    def test_from_pretrained_save_pretrained(self):
        pass

    def check_over_forward(self, time_step=0, **forward_kwargs):
        kwargs = dict(self.forward_default_kwargs)
        num_inference_steps = kwargs.pop("num_inference_steps", None)
        sample = self.dummy_sample
        residual = 0.1 * sample
        dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10]

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)
            scheduler.set_timesteps(num_inference_steps)

            # copy over dummy past residuals (must be after setting timesteps)
            scheduler.model_outputs = dummy_past_residuals[: scheduler.config.solver_order]

            with tempfile.TemporaryDirectory() as tmpdirname:
                scheduler.save_config(tmpdirname)
                new_scheduler = scheduler_class.from_config(tmpdirname)
                # copy over dummy past residuals
                new_scheduler.set_timesteps(num_inference_steps)

                # copy over dummy past residual (must be after setting timesteps)
                new_scheduler.model_outputs = dummy_past_residuals[: new_scheduler.config.solver_order]

            output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample
            new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample

            assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

    def full_loop(self, **config):
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config(**config)
        scheduler = scheduler_class(**scheduler_config)

        num_inference_steps = 10
        model = self.dummy_model()
        sample = self.dummy_sample_deter
        scheduler.set_timesteps(num_inference_steps)

        for i, t in enumerate(scheduler.timesteps):
            residual = model(sample, t)
            sample = scheduler.step(residual, t, sample).prev_sample

        return sample

    def test_step_shape(self):
        kwargs = dict(self.forward_default_kwargs)
        num_inference_steps = kwargs.pop("num_inference_steps", None)

        for scheduler_class in self.scheduler_classes:
            scheduler_config = self.get_scheduler_config()
            scheduler = scheduler_class(**scheduler_config)

            sample = self.dummy_sample
            residual = 0.1 * sample

            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
                scheduler.set_timesteps(num_inference_steps)
            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                kwargs["num_inference_steps"] = num_inference_steps

            # copy over dummy past residuals (must be done after set_timesteps)
            dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.10]
            scheduler.model_outputs = dummy_past_residuals[: scheduler.config.solver_order]

            time_step_0 = scheduler.timesteps[5]
            time_step_1 = scheduler.timesteps[6]

            output_0 = scheduler.step(residual, time_step_0, sample, **kwargs).prev_sample
            output_1 = scheduler.step(residual, time_step_1, sample, **kwargs).prev_sample

            self.assertEqual(output_0.shape, sample.shape)
            self.assertEqual(output_0.shape, output_1.shape)

    def test_timesteps(self):
        for timesteps in [25, 50, 100, 999, 1000]:
            self.check_over_configs(num_train_timesteps=timesteps)

    def test_thresholding(self):
        self.check_over_configs(thresholding=False)
        for order in [1, 2, 3]:
            for solver_type in ["midpoint", "heun"]:
                for threshold in [0.5, 1.0, 2.0]:
                    for predict_epsilon in [True, False]:
                        self.check_over_configs(
                            thresholding=True,
                            predict_epsilon=predict_epsilon,
                            sample_max_value=threshold,
                            algorithm_type="dpmsolver++",
                            solver_order=order,
                            solver_type=solver_type,
                        )

    def test_solver_order_and_type(self):
        for algorithm_type in ["dpmsolver", "dpmsolver++"]:
            for solver_type in ["midpoint", "heun"]:
                for order in [1, 2, 3]:
                    for predict_epsilon in [True, False]:
                        self.check_over_configs(
                            solver_order=order,
                            solver_type=solver_type,
                            predict_epsilon=predict_epsilon,
                            algorithm_type=algorithm_type,
                        )
                        sample = self.full_loop(
                            solver_order=order,
                            solver_type=solver_type,
                            predict_epsilon=predict_epsilon,
                            algorithm_type=algorithm_type,
                        )
                        assert not torch.isnan(sample).any(), "Samples have nan numbers"

    def test_lower_order_final(self):
        self.check_over_configs(lower_order_final=True)
        self.check_over_configs(lower_order_final=False)

    def test_inference_steps(self):
        for num_inference_steps in [1, 2, 3, 5, 10, 50, 100, 999, 1000]:
            self.check_over_forward(num_inference_steps=num_inference_steps, time_step=0)

    def test_full_loop_no_noise(self):
        sample = self.full_loop()
        result_mean = torch.mean(torch.abs(sample))

        assert abs(result_mean.item() - 0.3301) < 1e-3
class PNDMSchedulerTest(SchedulerCommonTest):
    scheduler_classes = (PNDMScheduler,)
    forward_default_kwargs = (("num_inference_steps", 50),)
......