Unverified Commit 097049b8 authored by Sam Shleifer's avatar Sam Shleifer Committed by GitHub
Browse files

Distributed Trainer: 2 little fixes (#7461)

* reset model.config

* Update src/transformers/trainer.py

* use lower case tensor

* Just tensor change
parent 0acd1ffa
@@ -203,7 +203,7 @@ def distributed_broadcast_scalars(
 ) -> "torch.Tensor":
     if is_torch_available():
         try:
-            tensorized_scalar = torch.Tensor(scalars).cuda()
+            tensorized_scalar = torch.tensor(scalars).cuda()
             output_tensors = [tensorized_scalar.clone() for _ in range(torch.distributed.get_world_size())]
             torch.distributed.all_gather(output_tensors, tensorized_scalar)
             concat = torch.cat(output_tensors, dim=0)
...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment