Unverified Commit bc7a4d49 authored by Younes Belkada's avatar Younes Belkada Committed by GitHub
Browse files

[`PEFT`] Fix scale unscale with LoRA adapters (#5417)



* fix scale unscale v1

* final fixes + CI

* fix slow test

* oops

* fix copies

* oops

* oops

* fix

* style

* fix copies

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 8dba1808
...@@ -405,7 +405,7 @@ class StableDiffusionLDM3DPipeline( ...@@ -405,7 +405,7 @@ class StableDiffusionLDM3DPipeline(
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND: if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers # Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder) unscale_lora_layers(self.text_encoder, lora_scale)
return prompt_embeds, negative_prompt_embeds return prompt_embeds, negative_prompt_embeds
......
...@@ -374,7 +374,7 @@ class StableDiffusionModelEditingPipeline(DiffusionPipeline, TextualInversionLoa ...@@ -374,7 +374,7 @@ class StableDiffusionModelEditingPipeline(DiffusionPipeline, TextualInversionLoa
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND: if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers # Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder) unscale_lora_layers(self.text_encoder, lora_scale)
return prompt_embeds, negative_prompt_embeds return prompt_embeds, negative_prompt_embeds
......
...@@ -358,7 +358,7 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM ...@@ -358,7 +358,7 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND: if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers # Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder) unscale_lora_layers(self.text_encoder, lora_scale)
return prompt_embeds, negative_prompt_embeds return prompt_embeds, negative_prompt_embeds
......
...@@ -389,7 +389,7 @@ class StableDiffusionParadigmsPipeline( ...@@ -389,7 +389,7 @@ class StableDiffusionParadigmsPipeline(
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND: if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers # Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder) unscale_lora_layers(self.text_encoder, lora_scale)
return prompt_embeds, negative_prompt_embeds return prompt_embeds, negative_prompt_embeds
......
...@@ -579,7 +579,7 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline): ...@@ -579,7 +579,7 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND: if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers # Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder) unscale_lora_layers(self.text_encoder, lora_scale)
return prompt_embeds, negative_prompt_embeds return prompt_embeds, negative_prompt_embeds
......
...@@ -381,7 +381,7 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin) ...@@ -381,7 +381,7 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin)
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND: if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers # Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder) unscale_lora_layers(self.text_encoder, lora_scale)
return prompt_embeds, negative_prompt_embeds return prompt_embeds, negative_prompt_embeds
......
...@@ -372,7 +372,7 @@ class StableDiffusionUpscalePipeline( ...@@ -372,7 +372,7 @@ class StableDiffusionUpscalePipeline(
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND: if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers # Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder) unscale_lora_layers(self.text_encoder, lora_scale)
return prompt_embeds, negative_prompt_embeds return prompt_embeds, negative_prompt_embeds
......
...@@ -479,7 +479,7 @@ class StableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL ...@@ -479,7 +479,7 @@ class StableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND: if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers # Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder) unscale_lora_layers(self.text_encoder, lora_scale)
return prompt_embeds, negative_prompt_embeds return prompt_embeds, negative_prompt_embeds
......
...@@ -433,7 +433,7 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin ...@@ -433,7 +433,7 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND: if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers # Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder) unscale_lora_layers(self.text_encoder, lora_scale)
return prompt_embeds, negative_prompt_embeds return prompt_embeds, negative_prompt_embeds
......
...@@ -436,12 +436,12 @@ class StableDiffusionXLPipeline( ...@@ -436,12 +436,12 @@ class StableDiffusionXLPipeline(
if self.text_encoder is not None: if self.text_encoder is not None:
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers # Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder) unscale_lora_layers(self.text_encoder, lora_scale)
if self.text_encoder_2 is not None: if self.text_encoder_2 is not None:
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers # Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder_2) unscale_lora_layers(self.text_encoder_2, lora_scale)
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
......
...@@ -440,12 +440,12 @@ class StableDiffusionXLImg2ImgPipeline( ...@@ -440,12 +440,12 @@ class StableDiffusionXLImg2ImgPipeline(
if self.text_encoder is not None: if self.text_encoder is not None:
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers # Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder) unscale_lora_layers(self.text_encoder, lora_scale)
if self.text_encoder_2 is not None: if self.text_encoder_2 is not None:
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers # Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder_2) unscale_lora_layers(self.text_encoder_2, lora_scale)
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
......
...@@ -590,12 +590,12 @@ class StableDiffusionXLInpaintPipeline( ...@@ -590,12 +590,12 @@ class StableDiffusionXLInpaintPipeline(
if self.text_encoder is not None: if self.text_encoder is not None:
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers # Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder) unscale_lora_layers(self.text_encoder, lora_scale)
if self.text_encoder_2 is not None: if self.text_encoder_2 is not None:
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers # Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder_2) unscale_lora_layers(self.text_encoder_2, lora_scale)
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
......
...@@ -429,7 +429,7 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline): ...@@ -429,7 +429,7 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline):
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND: if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers # Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder) unscale_lora_layers(self.text_encoder, lora_scale)
return prompt_embeds, negative_prompt_embeds return prompt_embeds, negative_prompt_embeds
......
...@@ -450,12 +450,12 @@ class StableDiffusionXLAdapterPipeline( ...@@ -450,12 +450,12 @@ class StableDiffusionXLAdapterPipeline(
if self.text_encoder is not None: if self.text_encoder is not None:
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers # Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder) unscale_lora_layers(self.text_encoder, lora_scale)
if self.text_encoder_2 is not None: if self.text_encoder_2 is not None:
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers # Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder_2) unscale_lora_layers(self.text_encoder_2, lora_scale)
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
......
...@@ -361,7 +361,7 @@ class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lora ...@@ -361,7 +361,7 @@ class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lora
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND: if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers # Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder) unscale_lora_layers(self.text_encoder, lora_scale)
return prompt_embeds, negative_prompt_embeds return prompt_embeds, negative_prompt_embeds
......
...@@ -423,7 +423,7 @@ class VideoToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor ...@@ -423,7 +423,7 @@ class VideoToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND: if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers # Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder) unscale_lora_layers(self.text_encoder, lora_scale)
return prompt_embeds, negative_prompt_embeds return prompt_embeds, negative_prompt_embeds
......
...@@ -556,7 +556,7 @@ class UniDiffuserPipeline(DiffusionPipeline): ...@@ -556,7 +556,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND: if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers # Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder) unscale_lora_layers(self.text_encoder, lora_scale)
return prompt_embeds, negative_prompt_embeds return prompt_embeds, negative_prompt_embeds
......
...@@ -1371,7 +1371,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): ...@@ -1371,7 +1371,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
if USE_PEFT_BACKEND: if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer # remove `lora_scale` from each PEFT layer
unscale_lora_layers(self) unscale_lora_layers(self, lora_scale)
if not return_dict: if not return_dict:
return (sample,) return (sample,)
......
...@@ -16,6 +16,7 @@ PEFT utilities: Utilities related to peft library ...@@ -16,6 +16,7 @@ PEFT utilities: Utilities related to peft library
""" """
import collections import collections
import importlib import importlib
from typing import Optional
from packaging import version from packaging import version
...@@ -91,21 +92,28 @@ def scale_lora_layers(model, weight): ...@@ -91,21 +92,28 @@ def scale_lora_layers(model, weight):
module.scale_layer(weight) module.scale_layer(weight)
def unscale_lora_layers(model): def unscale_lora_layers(model, weight: Optional[float] = None):
""" """
Removes the previously passed weight given to the LoRA layers of the model. Removes the previously passed weight given to the LoRA layers of the model.
Args: Args:
model (`torch.nn.Module`): model (`torch.nn.Module`):
The model to scale. The model to scale.
weight (`float`): weight (`float`, *optional*):
The weight to be given to the LoRA layers. The weight to be given to the LoRA layers. If no scale is passed the scale of the lora layer will be
re-initialized to the correct value. If 0.0 is passed, we will re-initialize the scale with the correct
value.
""" """
from peft.tuners.tuners_utils import BaseTunerLayer from peft.tuners.tuners_utils import BaseTunerLayer
for module in model.modules(): for module in model.modules():
if isinstance(module, BaseTunerLayer): if isinstance(module, BaseTunerLayer):
module.unscale_layer() if weight is not None and weight != 0:
module.unscale_layer(weight)
elif weight is not None and weight == 0:
for adapter_name in module.active_adapters:
# if weight == 0 unscale should re-set the scale to the original value.
module.set_scale(adapter_name, 1.0)
def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True): def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True):
...@@ -184,7 +192,7 @@ def set_weights_and_activate_adapters(model, adapter_names, weights): ...@@ -184,7 +192,7 @@ def set_weights_and_activate_adapters(model, adapter_names, weights):
module.set_adapter(adapter_name) module.set_adapter(adapter_name)
else: else:
module.active_adapter = adapter_name module.active_adapter = adapter_name
module.scale_layer(weight) module.set_scale(adapter_name, weight)
# set multiple active adapters # set multiple active adapters
for module in model.modules(): for module in model.modules():
......
...@@ -775,6 +775,79 @@ class PeftLoraLoaderMixinTests: ...@@ -775,6 +775,79 @@ class PeftLoraLoaderMixinTests:
"output with no lora and output with lora disabled should give same results", "output with no lora and output with lora disabled should give same results",
) )
def test_simple_inference_with_text_unet_multi_adapter_weighted(self):
"""
Tests a simple inference with lora attached to text encoder and unet, attaches
multiple adapters and set them
"""
components, _, text_lora_config, unet_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(self.torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
pipe.unet.add_adapter(unet_lora_config, "adapter-1")
pipe.unet.add_adapter(unet_lora_config, "adapter-2")
self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
if self.has_two_text_encoders:
pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
self.assertTrue(
self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
pipe.set_adapters("adapter-1")
output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0)).images
pipe.set_adapters("adapter-2")
output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0)).images
pipe.set_adapters(["adapter-1", "adapter-2"])
output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0)).images
# Fuse and unfuse should lead to the same results
self.assertFalse(
np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
"Adapter 1 and 2 should give different results",
)
self.assertFalse(
np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
"Adapter 1 and mixed adapters should give different results",
)
self.assertFalse(
np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
"Adapter 2 and mixed adapters should give different results",
)
pipe.set_adapters(["adapter-1", "adapter-2"], [0.5, 0.6])
output_adapter_mixed_weighted = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertFalse(
np.allclose(output_adapter_mixed_weighted, output_adapter_mixed, atol=1e-3, rtol=1e-3),
"Weighted adapter and mixed adapter should give different results",
)
pipe.disable_lora()
output_disabled = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertTrue(
np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
"output with no lora and output with lora disabled should give same results",
)
def test_lora_fuse_nan(self): def test_lora_fuse_nan(self):
components, _, text_lora_config, unet_lora_config = self.get_dummy_components() components, _, text_lora_config, unet_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
...@@ -1073,7 +1146,6 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase): ...@@ -1073,7 +1146,6 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
expected_slice_scale = np.array([0.538, 0.539, 0.540, 0.540, 0.542, 0.539, 0.538, 0.541, 0.539]) expected_slice_scale = np.array([0.538, 0.539, 0.540, 0.540, 0.542, 0.539, 0.538, 0.541, 0.539])
predicted_slice = images[0, -3:, -3:, -1].flatten() predicted_slice = images[0, -3:, -3:, -1].flatten()
# import pdb; pdb.set_trace()
self.assertTrue(np.allclose(expected_slice_scale, predicted_slice, atol=1e-3, rtol=1e-3)) self.assertTrue(np.allclose(expected_slice_scale, predicted_slice, atol=1e-3, rtol=1e-3))
pipe.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") pipe.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
...@@ -1106,7 +1178,7 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase): ...@@ -1106,7 +1178,7 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
output_type="np", output_type="np",
).images ).images
predicted_slice = images[0, -3:, -3:, -1].flatten() predicted_slice = images[0, -3:, -3:, -1].flatten()
expected_slice_scale = np.array([0.5977, 0.5985, 0.6039, 0.5976, 0.6025, 0.6036, 0.5946, 0.5979, 0.5998]) expected_slice_scale = np.array([0.5888, 0.5897, 0.5946, 0.5888, 0.5935, 0.5946, 0.5857, 0.5891, 0.5909])
self.assertTrue(np.allclose(expected_slice_scale, predicted_slice, atol=1e-3, rtol=1e-3)) self.assertTrue(np.allclose(expected_slice_scale, predicted_slice, atol=1e-3, rtol=1e-3))
# Lora disabled # Lora disabled
...@@ -1120,7 +1192,7 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase): ...@@ -1120,7 +1192,7 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
output_type="np", output_type="np",
).images ).images
predicted_slice = images[0, -3:, -3:, -1].flatten() predicted_slice = images[0, -3:, -3:, -1].flatten()
expected_slice_scale = np.array([0.54625, 0.5473, 0.5495, 0.5465, 0.5476, 0.5461, 0.5452, 0.5485, 0.5493]) expected_slice_scale = np.array([0.5456, 0.5466, 0.5487, 0.5458, 0.5469, 0.5454, 0.5446, 0.5479, 0.5487])
self.assertTrue(np.allclose(expected_slice_scale, predicted_slice, atol=1e-3, rtol=1e-3)) self.assertTrue(np.allclose(expected_slice_scale, predicted_slice, atol=1e-3, rtol=1e-3))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment