Unverified Commit b866cdbd authored by Dipika Sikka's avatar Dipika Sikka Committed by GitHub
Browse files

[Misc] Add assertion and helpful message for marlin24 compressed models (#11388)

parent 2e726680
......@@ -61,6 +61,10 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
assert params_dtype == torch.float16, (
"float16 is required for marlin24 compressd models. Set dtype=torch.float16" # noqa: E501
)
pack_factor = 32 // self.quant_type.size_bits
output_size_per_partition = sum(output_partition_sizes)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment