Unverified Commit 41ba8c0b authored by Aryan's avatar Aryan Committed by GitHub
Browse files

Add support for sharded models when TorchAO quantization is enabled (#10256)

* add sharded + device_map check
parent 31912484
...@@ -802,7 +802,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): ...@@ -802,7 +802,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
revision=revision, revision=revision,
subfolder=subfolder or "", subfolder=subfolder or "",
) )
if hf_quantizer is not None: if hf_quantizer is not None and is_bnb_quantization_method:
model_file = _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata) model_file = _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata)
logger.info("Merged sharded checkpoints as `hf_quantizer` is not None.") logger.info("Merged sharded checkpoints as `hf_quantizer` is not None.")
is_sharded = False is_sharded = False
......
...@@ -278,13 +278,14 @@ class TorchAoTest(unittest.TestCase): ...@@ -278,13 +278,14 @@ class TorchAoTest(unittest.TestCase):
self.assertEqual(weight.quant_max, 15) self.assertEqual(weight.quant_max, 15)
self.assertTrue(isinstance(weight.layout_type, TensorCoreTiledLayoutType)) self.assertTrue(isinstance(weight.layout_type, TensorCoreTiledLayoutType))
def test_offload(self): def test_device_map(self):
""" """
Test if the quantized model int4 weight-only is working properly with cpu/disk offload. Also verifies Test if the quantized model int4 weight-only is working properly with "auto" and custom device maps.
that the device map is correctly set (in the `hf_device_map` attribute of the model). The custom device map performs cpu/disk offloading as well. Also verifies that the device map is
correctly set (in the `hf_device_map` attribute of the model).
""" """
device_map_offload = { custom_device_map_dict = {
"time_text_embed": torch_device, "time_text_embed": torch_device,
"context_embedder": torch_device, "context_embedder": torch_device,
"x_embedder": torch_device, "x_embedder": torch_device,
...@@ -293,26 +294,49 @@ class TorchAoTest(unittest.TestCase): ...@@ -293,26 +294,49 @@ class TorchAoTest(unittest.TestCase):
"norm_out": torch_device, "norm_out": torch_device,
"proj_out": "cpu", "proj_out": "cpu",
} }
device_maps = ["auto", custom_device_map_dict]
inputs = self.get_dummy_tensor_inputs(torch_device) inputs = self.get_dummy_tensor_inputs(torch_device)
expected_slice = np.array([0.3457, -0.0366, 0.0105, -0.2275, -0.4941, 0.4395, -0.166, -0.6641, 0.4375])
for device_map in device_maps:
device_map_to_compare = {"": 0} if device_map == "auto" else device_map
# Test non-sharded model
with tempfile.TemporaryDirectory() as offload_folder: with tempfile.TemporaryDirectory() as offload_folder:
quantization_config = TorchAoConfig("int4_weight_only", group_size=64) quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
quantized_model = FluxTransformer2DModel.from_pretrained( quantized_model = FluxTransformer2DModel.from_pretrained(
"hf-internal-testing/tiny-flux-pipe", "hf-internal-testing/tiny-flux-pipe",
subfolder="transformer", subfolder="transformer",
quantization_config=quantization_config, quantization_config=quantization_config,
device_map=device_map_offload, device_map=device_map,
torch_dtype=torch.bfloat16, torch_dtype=torch.bfloat16,
offload_folder=offload_folder, offload_folder=offload_folder,
) )
self.assertTrue(quantized_model.hf_device_map == device_map_offload) self.assertTrue(quantized_model.hf_device_map == device_map_to_compare)
output = quantized_model(**inputs)[0]
output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
# Test sharded model
with tempfile.TemporaryDirectory() as offload_folder:
quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
quantized_model = FluxTransformer2DModel.from_pretrained(
"hf-internal-testing/tiny-flux-sharded",
subfolder="transformer",
quantization_config=quantization_config,
device_map=device_map,
torch_dtype=torch.bfloat16,
offload_folder=offload_folder,
)
self.assertTrue(quantized_model.hf_device_map == device_map_to_compare)
output = quantized_model(**inputs)[0] output = quantized_model(**inputs)[0]
output_slice = output.flatten()[-9:].detach().float().cpu().numpy() output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
expected_slice = np.array([0.3457, -0.0366, 0.0105, -0.2275, -0.4941, 0.4395, -0.166, -0.6641, 0.4375])
self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3)) self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
def test_modules_to_not_convert(self): def test_modules_to_not_convert(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment