Unverified Commit d7d51a7e authored by Jacob Platin's avatar Jacob Platin Committed by GitHub
Browse files

[Bugfix] Fix Qwen3.5-FP8 Weight Loading Error on TPU (#37348)


Signed-off-by: default avatarJacob Platin <jacobplatin@google.com>
parent 3c3c0842
......@@ -768,6 +768,13 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
# Special case for Quantization.
# If quantized, we need to adjust the offset and size to account
# for the packing.
# Add check to adjust the size/offset for FP8 block scales
if isinstance(param, BlockQuantScaleParameter):
weight_block_size = getattr(self, "weight_block_size", None)
shard_size, shard_offset = adjust_block_scale_shard(
weight_block_size, shard_size, shard_offset
)
if packed_dim == output_dim:
shard_size = shard_size // param.packed_factor
shard_offset = shard_offset // param.packed_factor
......@@ -1218,6 +1225,13 @@ class QKVParallelLinear(ColumnParallelLinear):
# Special case for Quantized Weights.
# If quantized, we need to adjust the offset and size to account
# for the packing.
# Add check to adjust the size/offset for FP8 block scales
if isinstance(param, BlockQuantScaleParameter):
weight_block_size = getattr(self, "weight_block_size", None)
shard_size, shard_offset = adjust_block_scale_shard(
weight_block_size, shard_size, shard_offset
)
if packed_dim == output_dim:
shard_size = shard_size // param.packed_factor
shard_offset = shard_offset // param.packed_factor
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment