"sr_rng_state must be of shape (num_devices, 4) when stochastic_rounding is"
f" True and is_outer is True but received {sr_rng_state_aval.shape}"
)
else:
assertsr_rng_state_aval.shape==(4,),(
"Sharded sr_rng_state must be of shape (4,) per device when"
# We cannot assert that the shape is exactly (4,) here: if the quantized data is not perfectly sharded across all devices, each device receives extra RNG state. For example, this can occur under data parallelism when the weights are not sharded. This is okay because the extra RNG state is simply not used and each device still has a unique RNG state.
assertsr_rng_state_aval.size>=4,(
"Sharded sr_rng_state must have at least 4 elements per device when"
f" stochastic_rounding is True but received {sr_rng_state_aval.shape}"
)
...
...
@@ -552,8 +552,13 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
# We currently only have a single flag 'use_rht' on the quantizer. To avoid an unused rowwise flag, we assume RHT is only used for colwise quantization for now.