"tests/models/vscode:/vscode.git/clone" did not exist on "0d0d77693f79c7f7d39bba6921cc9741f00de988"
Commit 85788bae authored by Kiyoung Kim's avatar Kiyoung Kim Committed by Lysandre
Browse files

Revert "Gradient accumulation for TFTrainer (#9585)"

This reverts commit 3f40070c.
parent 82498cbc
......@@ -638,9 +638,7 @@ class TFTrainer:
reduced_features = {
k: ft[: self.args.train_batch_size // self.args.n_replicas] for k, ft in features.items()
}
reduced_labels = {
k: lbl[: self.args.train_batch_size // self.args.n_replicas] for k, lbl in labels.items()
}
reduced_labels = labels[: self.args.train_batch_size // self.args.n_replicas]
self.training_step(reduced_features, reduced_labels, nb_instances_in_global_batch)
......@@ -652,13 +650,9 @@ class TFTrainer:
for k, ft in features.items()
}
labels = {
k: tf.concat(
[lbl[self.args.train_batch_size // self.args.n_replicas :], reduced_labels[k]],
axis=0,
labels = tf.concat(
[labels[self.args.train_batch_size // self.args.n_replicas :], reduced_labels], axis=0
)
for k, lbl in labels.items()
}
gradients = self.gradient_accumulator.gradients
gradients = [
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment