Unverified Commit 4c7095ca authored by Selvaraj Anandaraj's avatar Selvaraj Anandaraj Committed by GitHub
Browse files
parent 96ee7173
......@@ -193,7 +193,7 @@ def _cast_master_weights_to_fp8_delayed_scaling(params, group, use_fsdp_shard_mo
quantizer.update_quantized(master_weight.view(1, -1), shard_model_weight_fp8)
if len(amaxes) > 0:
dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device=amaxes[0].device)
dummy_overflow_buf = torch.zeros(1, dtype=torch.int, device=amaxes[0].device)
# Reduce amaxes.
packed_amaxes = torch.empty(len(amaxes), dtype=torch.float32, device=amaxes[0].device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment