Unverified Commit b75b204a authored by puhuk's avatar puhuk Committed by GitHub
Browse files

Fix max_shift value in flux and related functions to 1.15 (issue #10675) (#10807)

This PR updates the max_shift value in flux to 1.15 for consistency across the codebase. In addition to modifying max_shift in flux, all related functions that copy and use this logic, such as calculate_shift in `src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py`, have also been updated to ensure uniform behavior.
parent c14057c8
...@@ -87,7 +87,7 @@ def calculate_shift( ...@@ -87,7 +87,7 @@ def calculate_shift(
base_seq_len: int = 256, base_seq_len: int = 256,
max_seq_len: int = 4096, max_seq_len: int = 4096,
base_shift: float = 0.5, base_shift: float = 0.5,
max_shift: float = 1.16, max_shift: float = 1.15,
): ):
m = (max_shift - base_shift) / (max_seq_len - base_seq_len) m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = base_shift - m * base_seq_len b = base_shift - m * base_seq_len
...@@ -878,7 +878,7 @@ class FluxDifferentialImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin): ...@@ -878,7 +878,7 @@ class FluxDifferentialImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
self.scheduler.config.get("base_image_seq_len", 256), self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.get("max_image_seq_len", 4096), self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.get("base_shift", 0.5), self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.get("max_shift", 1.16), self.scheduler.config.get("max_shift", 1.15),
) )
timesteps, num_inference_steps = retrieve_timesteps( timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, self.scheduler,
......
...@@ -94,7 +94,7 @@ def calculate_shift( ...@@ -94,7 +94,7 @@ def calculate_shift(
base_seq_len: int = 256, base_seq_len: int = 256,
max_seq_len: int = 4096, max_seq_len: int = 4096,
base_shift: float = 0.5, base_shift: float = 0.5,
max_shift: float = 1.16, max_shift: float = 1.15,
): ):
m = (max_shift - base_shift) / (max_seq_len - base_seq_len) m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = base_shift - m * base_seq_len b = base_shift - m * base_seq_len
...@@ -823,7 +823,7 @@ class RFInversionFluxPipeline( ...@@ -823,7 +823,7 @@ class RFInversionFluxPipeline(
self.scheduler.config.get("base_image_seq_len", 256), self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.get("max_image_seq_len", 4096), self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.get("base_shift", 0.5), self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.get("max_shift", 1.16), self.scheduler.config.get("max_shift", 1.15),
) )
timesteps, num_inference_steps = retrieve_timesteps( timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, self.scheduler,
...@@ -993,7 +993,7 @@ class RFInversionFluxPipeline( ...@@ -993,7 +993,7 @@ class RFInversionFluxPipeline(
self.scheduler.config.get("base_image_seq_len", 256), self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.get("max_image_seq_len", 4096), self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.get("base_shift", 0.5), self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.get("max_shift", 1.16), self.scheduler.config.get("max_shift", 1.15),
) )
timesteps, num_inversion_steps = retrieve_timesteps( timesteps, num_inversion_steps = retrieve_timesteps(
self.scheduler, self.scheduler,
......
...@@ -91,7 +91,7 @@ def calculate_shift( ...@@ -91,7 +91,7 @@ def calculate_shift(
base_seq_len: int = 256, base_seq_len: int = 256,
max_seq_len: int = 4096, max_seq_len: int = 4096,
base_shift: float = 0.5, base_shift: float = 0.5,
max_shift: float = 1.16, max_shift: float = 1.15,
): ):
m = (max_shift - base_shift) / (max_seq_len - base_seq_len) m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = base_shift - m * base_seq_len b = base_shift - m * base_seq_len
...@@ -1041,7 +1041,7 @@ class FluxSemanticGuidancePipeline( ...@@ -1041,7 +1041,7 @@ class FluxSemanticGuidancePipeline(
self.scheduler.config.get("base_image_seq_len", 256), self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.get("max_image_seq_len", 4096), self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.get("base_shift", 0.5), self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.get("max_shift", 1.16), self.scheduler.config.get("max_shift", 1.15),
) )
timesteps, num_inference_steps = retrieve_timesteps( timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, self.scheduler,
......
...@@ -70,7 +70,7 @@ def calculate_shift( ...@@ -70,7 +70,7 @@ def calculate_shift(
base_seq_len: int = 256, base_seq_len: int = 256,
max_seq_len: int = 4096, max_seq_len: int = 4096,
base_shift: float = 0.5, base_shift: float = 0.5,
max_shift: float = 1.16, max_shift: float = 1.15,
): ):
m = (max_shift - base_shift) / (max_seq_len - base_seq_len) m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = base_shift - m * base_seq_len b = base_shift - m * base_seq_len
...@@ -759,7 +759,7 @@ class FluxCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixi ...@@ -759,7 +759,7 @@ class FluxCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixi
self.scheduler.config.get("base_image_seq_len", 256), self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.get("max_image_seq_len", 4096), self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.get("base_shift", 0.5), self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.get("max_shift", 1.16), self.scheduler.config.get("max_shift", 1.15),
) )
timesteps, num_inference_steps = retrieve_timesteps( timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, self.scheduler,
......
...@@ -75,7 +75,7 @@ def calculate_shift( ...@@ -75,7 +75,7 @@ def calculate_shift(
base_seq_len: int = 256, base_seq_len: int = 256,
max_seq_len: int = 4096, max_seq_len: int = 4096,
base_shift: float = 0.5, base_shift: float = 0.5,
max_shift: float = 1.16, max_shift: float = 1.15,
): ):
m = (max_shift - base_shift) / (max_seq_len - base_seq_len) m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = base_shift - m * base_seq_len b = base_shift - m * base_seq_len
...@@ -849,7 +849,7 @@ class FluxPipeline( ...@@ -849,7 +849,7 @@ class FluxPipeline(
self.scheduler.config.get("base_image_seq_len", 256), self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.get("max_image_seq_len", 4096), self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.get("base_shift", 0.5), self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.get("max_shift", 1.16), self.scheduler.config.get("max_shift", 1.15),
) )
timesteps, num_inference_steps = retrieve_timesteps( timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, self.scheduler,
......
...@@ -88,7 +88,7 @@ def calculate_shift( ...@@ -88,7 +88,7 @@ def calculate_shift(
base_seq_len: int = 256, base_seq_len: int = 256,
max_seq_len: int = 4096, max_seq_len: int = 4096,
base_shift: float = 0.5, base_shift: float = 0.5,
max_shift: float = 1.16, max_shift: float = 1.15,
): ):
m = (max_shift - base_shift) / (max_seq_len - base_seq_len) m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = base_shift - m * base_seq_len b = base_shift - m * base_seq_len
...@@ -802,7 +802,7 @@ class FluxControlPipeline( ...@@ -802,7 +802,7 @@ class FluxControlPipeline(
self.scheduler.config.get("base_image_seq_len", 256), self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.get("max_image_seq_len", 4096), self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.get("base_shift", 0.5), self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.get("max_shift", 1.16), self.scheduler.config.get("max_shift", 1.15),
) )
timesteps, num_inference_steps = retrieve_timesteps( timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, self.scheduler,
......
...@@ -93,7 +93,7 @@ def calculate_shift( ...@@ -93,7 +93,7 @@ def calculate_shift(
base_seq_len: int = 256, base_seq_len: int = 256,
max_seq_len: int = 4096, max_seq_len: int = 4096,
base_shift: float = 0.5, base_shift: float = 0.5,
max_shift: float = 1.16, max_shift: float = 1.15,
): ):
m = (max_shift - base_shift) / (max_seq_len - base_seq_len) m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = base_shift - m * base_seq_len b = base_shift - m * base_seq_len
...@@ -810,7 +810,7 @@ class FluxControlImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSin ...@@ -810,7 +810,7 @@ class FluxControlImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSin
self.scheduler.config.get("base_image_seq_len", 256), self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.get("max_image_seq_len", 4096), self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.get("base_shift", 0.5), self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.get("max_shift", 1.16), self.scheduler.config.get("max_shift", 1.15),
) )
timesteps, num_inference_steps = retrieve_timesteps( timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, self.scheduler,
......
...@@ -119,7 +119,7 @@ def calculate_shift( ...@@ -119,7 +119,7 @@ def calculate_shift(
base_seq_len: int = 256, base_seq_len: int = 256,
max_seq_len: int = 4096, max_seq_len: int = 4096,
base_shift: float = 0.5, base_shift: float = 0.5,
max_shift: float = 1.16, max_shift: float = 1.15,
): ):
m = (max_shift - base_shift) / (max_seq_len - base_seq_len) m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = base_shift - m * base_seq_len b = base_shift - m * base_seq_len
...@@ -987,7 +987,7 @@ class FluxControlInpaintPipeline( ...@@ -987,7 +987,7 @@ class FluxControlInpaintPipeline(
self.scheduler.config.get("base_image_seq_len", 256), self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.get("max_image_seq_len", 4096), self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.get("base_shift", 0.5), self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.get("max_shift", 1.16), self.scheduler.config.get("max_shift", 1.15),
) )
timesteps, num_inference_steps = retrieve_timesteps( timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, self.scheduler,
......
...@@ -89,7 +89,7 @@ def calculate_shift( ...@@ -89,7 +89,7 @@ def calculate_shift(
base_seq_len: int = 256, base_seq_len: int = 256,
max_seq_len: int = 4096, max_seq_len: int = 4096,
base_shift: float = 0.5, base_shift: float = 0.5,
max_shift: float = 1.16, max_shift: float = 1.15,
): ):
m = (max_shift - base_shift) / (max_seq_len - base_seq_len) m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = base_shift - m * base_seq_len b = base_shift - m * base_seq_len
...@@ -877,7 +877,7 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF ...@@ -877,7 +877,7 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
self.scheduler.config.get("base_image_seq_len", 256), self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.get("max_image_seq_len", 4096), self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.get("base_shift", 0.5), self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.get("max_shift", 1.16), self.scheduler.config.get("max_shift", 1.15),
) )
timesteps, num_inference_steps = retrieve_timesteps( timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, self.scheduler,
......
...@@ -87,7 +87,7 @@ def calculate_shift( ...@@ -87,7 +87,7 @@ def calculate_shift(
base_seq_len: int = 256, base_seq_len: int = 256,
max_seq_len: int = 4096, max_seq_len: int = 4096,
base_shift: float = 0.5, base_shift: float = 0.5,
max_shift: float = 1.16, max_shift: float = 1.15,
): ):
m = (max_shift - base_shift) / (max_seq_len - base_seq_len) m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = base_shift - m * base_seq_len b = base_shift - m * base_seq_len
...@@ -865,7 +865,7 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From ...@@ -865,7 +865,7 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
self.scheduler.config.get("base_image_seq_len", 256), self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.get("max_image_seq_len", 4096), self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.get("base_shift", 0.5), self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.get("max_shift", 1.16), self.scheduler.config.get("max_shift", 1.15),
) )
timesteps, num_inference_steps = retrieve_timesteps( timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, self.scheduler,
......
...@@ -89,7 +89,7 @@ def calculate_shift( ...@@ -89,7 +89,7 @@ def calculate_shift(
base_seq_len: int = 256, base_seq_len: int = 256,
max_seq_len: int = 4096, max_seq_len: int = 4096,
base_shift: float = 0.5, base_shift: float = 0.5,
max_shift: float = 1.16, max_shift: float = 1.15,
): ):
m = (max_shift - base_shift) / (max_seq_len - base_seq_len) m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = base_shift - m * base_seq_len b = base_shift - m * base_seq_len
...@@ -1019,7 +1019,7 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From ...@@ -1019,7 +1019,7 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
self.scheduler.config.get("base_image_seq_len", 256), self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.get("max_image_seq_len", 4096), self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.get("base_shift", 0.5), self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.get("max_shift", 1.16), self.scheduler.config.get("max_shift", 1.15),
) )
timesteps, num_inference_steps = retrieve_timesteps( timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, self.scheduler,
......
...@@ -82,7 +82,7 @@ def calculate_shift( ...@@ -82,7 +82,7 @@ def calculate_shift(
base_seq_len: int = 256, base_seq_len: int = 256,
max_seq_len: int = 4096, max_seq_len: int = 4096,
base_shift: float = 0.5, base_shift: float = 0.5,
max_shift: float = 1.16, max_shift: float = 1.15,
): ):
m = (max_shift - base_shift) / (max_seq_len - base_seq_len) m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = base_shift - m * base_seq_len b = base_shift - m * base_seq_len
...@@ -884,7 +884,7 @@ class FluxFillPipeline( ...@@ -884,7 +884,7 @@ class FluxFillPipeline(
self.scheduler.config.get("base_image_seq_len", 256), self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.get("max_image_seq_len", 4096), self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.get("base_shift", 0.5), self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.get("max_shift", 1.16), self.scheduler.config.get("max_shift", 1.15),
) )
timesteps, num_inference_steps = retrieve_timesteps( timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, self.scheduler,
......
...@@ -77,7 +77,7 @@ def calculate_shift( ...@@ -77,7 +77,7 @@ def calculate_shift(
base_seq_len: int = 256, base_seq_len: int = 256,
max_seq_len: int = 4096, max_seq_len: int = 4096,
base_shift: float = 0.5, base_shift: float = 0.5,
max_shift: float = 1.16, max_shift: float = 1.15,
): ):
m = (max_shift - base_shift) / (max_seq_len - base_seq_len) m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = base_shift - m * base_seq_len b = base_shift - m * base_seq_len
...@@ -747,7 +747,7 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile ...@@ -747,7 +747,7 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
self.scheduler.config.get("base_image_seq_len", 256), self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.get("max_image_seq_len", 4096), self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.get("base_shift", 0.5), self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.get("max_shift", 1.16), self.scheduler.config.get("max_shift", 1.15),
) )
timesteps, num_inference_steps = retrieve_timesteps( timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, self.scheduler,
......
...@@ -74,7 +74,7 @@ def calculate_shift( ...@@ -74,7 +74,7 @@ def calculate_shift(
base_seq_len: int = 256, base_seq_len: int = 256,
max_seq_len: int = 4096, max_seq_len: int = 4096,
base_shift: float = 0.5, base_shift: float = 0.5,
max_shift: float = 1.16, max_shift: float = 1.15,
): ):
m = (max_shift - base_shift) / (max_seq_len - base_seq_len) m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = base_shift - m * base_seq_len b = base_shift - m * base_seq_len
...@@ -879,7 +879,7 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin): ...@@ -879,7 +879,7 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
self.scheduler.config.get("base_image_seq_len", 256), self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.get("max_image_seq_len", 4096), self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.get("base_shift", 0.5), self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.get("max_shift", 1.16), self.scheduler.config.get("max_shift", 1.15),
) )
timesteps, num_inference_steps = retrieve_timesteps( timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, self.scheduler,
......
...@@ -72,7 +72,7 @@ def calculate_shift( ...@@ -72,7 +72,7 @@ def calculate_shift(
base_seq_len: int = 256, base_seq_len: int = 256,
max_seq_len: int = 4096, max_seq_len: int = 4096,
base_shift: float = 0.5, base_shift: float = 0.5,
max_shift: float = 1.16, max_shift: float = 1.15,
): ):
m = (max_shift - base_shift) / (max_seq_len - base_seq_len) m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = base_shift - m * base_seq_len b = base_shift - m * base_seq_len
...@@ -680,7 +680,7 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi ...@@ -680,7 +680,7 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
self.scheduler.config.get("base_image_seq_len", 256), self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.get("max_image_seq_len", 4096), self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.get("base_shift", 0.5), self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.get("max_shift", 1.16), self.scheduler.config.get("max_shift", 1.15),
) )
timesteps, num_inference_steps = retrieve_timesteps( timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, self.scheduler,
......
...@@ -77,7 +77,7 @@ def calculate_shift( ...@@ -77,7 +77,7 @@ def calculate_shift(
base_seq_len: int = 256, base_seq_len: int = 256,
max_seq_len: int = 4096, max_seq_len: int = 4096,
base_shift: float = 0.5, base_shift: float = 0.5,
max_shift: float = 1.16, max_shift: float = 1.15,
): ):
m = (max_shift - base_shift) / (max_seq_len - base_seq_len) m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = base_shift - m * base_seq_len b = base_shift - m * base_seq_len
...@@ -750,7 +750,7 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo ...@@ -750,7 +750,7 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo
self.scheduler.config.get("base_image_seq_len", 256), self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.get("max_image_seq_len", 4096), self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.get("base_shift", 0.5), self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.get("max_shift", 1.16), self.scheduler.config.get("max_shift", 1.15),
) )
timesteps, num_inference_steps = retrieve_timesteps( timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, self.scheduler,
......
...@@ -64,7 +64,7 @@ def calculate_shift( ...@@ -64,7 +64,7 @@ def calculate_shift(
base_seq_len: int = 256, base_seq_len: int = 256,
max_seq_len: int = 4096, max_seq_len: int = 4096,
base_shift: float = 0.5, base_shift: float = 0.5,
max_shift: float = 1.16, max_shift: float = 1.15,
): ):
m = (max_shift - base_shift) / (max_seq_len - base_seq_len) m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = base_shift - m * base_seq_len b = base_shift - m * base_seq_len
......
...@@ -76,7 +76,7 @@ def calculate_shift( ...@@ -76,7 +76,7 @@ def calculate_shift(
base_seq_len: int = 256, base_seq_len: int = 256,
max_seq_len: int = 4096, max_seq_len: int = 4096,
base_shift: float = 0.5, base_shift: float = 0.5,
max_shift: float = 1.16, max_shift: float = 1.15,
): ):
m = (max_shift - base_shift) / (max_seq_len - base_seq_len) m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = base_shift - m * base_seq_len b = base_shift - m * base_seq_len
......
...@@ -83,7 +83,7 @@ def calculate_shift( ...@@ -83,7 +83,7 @@ def calculate_shift(
base_seq_len: int = 256, base_seq_len: int = 256,
max_seq_len: int = 4096, max_seq_len: int = 4096,
base_shift: float = 0.5, base_shift: float = 0.5,
max_shift: float = 1.16, max_shift: float = 1.15,
): ):
m = (max_shift - base_shift) / (max_seq_len - base_seq_len) m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = base_shift - m * base_seq_len b = base_shift - m * base_seq_len
......
...@@ -82,7 +82,7 @@ def calculate_shift( ...@@ -82,7 +82,7 @@ def calculate_shift(
base_seq_len: int = 256, base_seq_len: int = 256,
max_seq_len: int = 4096, max_seq_len: int = 4096,
base_shift: float = 0.5, base_shift: float = 0.5,
max_shift: float = 1.16, max_shift: float = 1.15,
): ):
m = (max_shift - base_shift) / (max_seq_len - base_seq_len) m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = base_shift - m * base_seq_len b = base_shift - m * base_seq_len
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment