Commit 3d4c4cd6 authored by Herbie Bradley's avatar Herbie Bradley
Browse files

Add evaluator temp code for debug

parent be95d945
......@@ -426,33 +426,34 @@ def evaluate(
original_dtype = metrics_tensor.dtype # store original dtype
# Gather sizes
torch_device_tensor = lm.accelerator.pad_across_processes(
metrics_tensor.to(torch.float32), pad_index=pad_value
)
gathered_item = lm.accelerator.gather(torch_device_tensor)
metrics_tensor.to(torch.float32), pad_index=pad_value
)
gathered_item = lm.accelerator.gather(torch_device_tensor)
if numitem > 0:
gathered_filtered = gathered_item[gathered_item[:, 0] != pad_value]
else:
gathered_filtered = gathered_item[gathered_item != pad_value]
# gathered_sizes = lm.accelerator.gather(num_requests)
# sizes = torch.stack(output_tensors)
# if lm.rank == 0:
# print(gathered_sizes)
# max_size = 26834
# # Use max size to pad
# metrics_tensor = metrics_tensor.to(torch.float32)
# if max_size != metrics_tensor.shape[0]:
# old_size = metrics_tensor.shape
# new_size = list(old_size)
# new_size[0] = max_size
# device_tensor = metrics_tensor.new_zeros(tuple(new_size)) + pad_value
# indices = tuple(
# slice(0, old_size[0]) if i == 0 else slice(None)
# for i in range(len(new_size))
# )
# device_tensor[indices] = metrics_tensor
# else:
# device_tensor = metrics_tensor
# gathered_item = lm.accelerator.gather(device_tensor)
# gathered_sizes = lm.accelerator.gather(num_requests)
# sizes = torch.stack(output_tensors)
# if lm.rank == 0:
# print(gathered_sizes)
# max_size = 26834
# # Use max size to pad
# metrics_tensor = metrics_tensor.to(torch.float32)
# if max_size != metrics_tensor.shape[0]:
# old_size = metrics_tensor.shape
# new_size = list(old_size)
# new_size[0] = max_size
# device_tensor = metrics_tensor.new_zeros(tuple(new_size)) + pad_value
# indices = tuple(
# slice(0, old_size[0]) if i == 0 else slice(None)
# for i in range(len(new_size))
# )
# device_tensor[indices] = metrics_tensor
# else:
# device_tensor = metrics_tensor
# gathered_item = lm.accelerator.gather(device_tensor)
gathered_item = (
gathered_filtered.to(original_dtype).cpu().detach().numpy().tolist()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment