"vscode:/vscode.git/clone" did not exist on "af539d6f0a989a3d663ef3a5f324b45d4e5571c8"
Unverified Commit 3f40070c authored by Kiyoung Kim's avatar Kiyoung Kim Committed by GitHub
Browse files

Gradient accumulation for TFTrainer (#9585)



* gradient accumulation for tftrainer

* label naming
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* label naming
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent e43f3b61
...@@ -638,7 +638,9 @@ class TFTrainer: ...@@ -638,7 +638,9 @@ class TFTrainer:
reduced_features = { reduced_features = {
k: ft[: self.args.train_batch_size // self.args.n_replicas] for k, ft in features.items() k: ft[: self.args.train_batch_size // self.args.n_replicas] for k, ft in features.items()
} }
reduced_labels = labels[: self.args.train_batch_size // self.args.n_replicas] reduced_labels = {
k: lbl[: self.args.train_batch_size // self.args.n_replicas] for k, lbl in labels.items()
}
self.training_step(reduced_features, reduced_labels, nb_instances_in_global_batch) self.training_step(reduced_features, reduced_labels, nb_instances_in_global_batch)
...@@ -650,9 +652,13 @@ class TFTrainer: ...@@ -650,9 +652,13 @@ class TFTrainer:
for k, ft in features.items() for k, ft in features.items()
} }
labels = tf.concat( labels = {
[labels[self.args.train_batch_size // self.args.n_replicas :], reduced_labels], axis=0 k: tf.concat(
[lbl[self.args.train_batch_size // self.args.n_replicas :], reduced_labels[k]],
axis=0,
) )
for k, lbl in labels.items()
}
gradients = self.gradient_accumulator.gradients gradients = self.gradient_accumulator.gradients
gradients = [ gradients = [
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment