"docs/source/git@developer.sourcefind.cn:norm/vllm.git" did not exist on "6fc2a38b110f9ba6037b31ee016f20df32426877"
Unverified Commit 41ba8c0b authored by Aryan's avatar Aryan Committed by GitHub
Browse files

Add support for sharded models when TorchAO quantization is enabled (#10256)

* add sharded + device_map check
parent 31912484
...@@ -802,7 +802,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): ...@@ -802,7 +802,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
revision=revision, revision=revision,
subfolder=subfolder or "", subfolder=subfolder or "",
) )
if hf_quantizer is not None: if hf_quantizer is not None and is_bnb_quantization_method:
model_file = _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata) model_file = _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata)
logger.info("Merged sharded checkpoints as `hf_quantizer` is not None.") logger.info("Merged sharded checkpoints as `hf_quantizer` is not None.")
is_sharded = False is_sharded = False
......
...@@ -278,13 +278,14 @@ class TorchAoTest(unittest.TestCase): ...@@ -278,13 +278,14 @@ class TorchAoTest(unittest.TestCase):
self.assertEqual(weight.quant_max, 15) self.assertEqual(weight.quant_max, 15)
self.assertTrue(isinstance(weight.layout_type, TensorCoreTiledLayoutType)) self.assertTrue(isinstance(weight.layout_type, TensorCoreTiledLayoutType))
def test_offload(self): def test_device_map(self):
""" """
Test if the quantized model int4 weight-only is working properly with cpu/disk offload. Also verifies Test if the quantized model int4 weight-only is working properly with "auto" and custom device maps.
that the device map is correctly set (in the `hf_device_map` attribute of the model). The custom device map performs cpu/disk offloading as well. Also verifies that the device map is
correctly set (in the `hf_device_map` attribute of the model).
""" """
device_map_offload = { custom_device_map_dict = {
"time_text_embed": torch_device, "time_text_embed": torch_device,
"context_embedder": torch_device, "context_embedder": torch_device,
"x_embedder": torch_device, "x_embedder": torch_device,
...@@ -293,27 +294,50 @@ class TorchAoTest(unittest.TestCase): ...@@ -293,27 +294,50 @@ class TorchAoTest(unittest.TestCase):
"norm_out": torch_device, "norm_out": torch_device,
"proj_out": "cpu", "proj_out": "cpu",
} }
device_maps = ["auto", custom_device_map_dict]
inputs = self.get_dummy_tensor_inputs(torch_device) inputs = self.get_dummy_tensor_inputs(torch_device)
expected_slice = np.array([0.3457, -0.0366, 0.0105, -0.2275, -0.4941, 0.4395, -0.166, -0.6641, 0.4375])
with tempfile.TemporaryDirectory() as offload_folder:
quantization_config = TorchAoConfig("int4_weight_only", group_size=64) for device_map in device_maps:
quantized_model = FluxTransformer2DModel.from_pretrained( device_map_to_compare = {"": 0} if device_map == "auto" else device_map
"hf-internal-testing/tiny-flux-pipe",
subfolder="transformer", # Test non-sharded model
quantization_config=quantization_config, with tempfile.TemporaryDirectory() as offload_folder:
device_map=device_map_offload, quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
torch_dtype=torch.bfloat16, quantized_model = FluxTransformer2DModel.from_pretrained(
offload_folder=offload_folder, "hf-internal-testing/tiny-flux-pipe",
) subfolder="transformer",
quantization_config=quantization_config,
self.assertTrue(quantized_model.hf_device_map == device_map_offload) device_map=device_map,
torch_dtype=torch.bfloat16,
output = quantized_model(**inputs)[0] offload_folder=offload_folder,
output_slice = output.flatten()[-9:].detach().float().cpu().numpy() )
expected_slice = np.array([0.3457, -0.0366, 0.0105, -0.2275, -0.4941, 0.4395, -0.166, -0.6641, 0.4375]) self.assertTrue(quantized_model.hf_device_map == device_map_to_compare)
self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
output = quantized_model(**inputs)[0]
output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
# Test sharded model
with tempfile.TemporaryDirectory() as offload_folder:
quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
quantized_model = FluxTransformer2DModel.from_pretrained(
"hf-internal-testing/tiny-flux-sharded",
subfolder="transformer",
quantization_config=quantization_config,
device_map=device_map,
torch_dtype=torch.bfloat16,
offload_folder=offload_folder,
)
self.assertTrue(quantized_model.hf_device_map == device_map_to_compare)
output = quantized_model(**inputs)[0]
output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3))
def test_modules_to_not_convert(self): def test_modules_to_not_convert(self):
quantization_config = TorchAoConfig("int8_weight_only", modules_to_not_convert=["transformer_blocks.0"]) quantization_config = TorchAoConfig("int8_weight_only", modules_to_not_convert=["transformer_blocks.0"])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment