"git@developer.sourcefind.cn:OpenDAS/deepspeed.git" did not exist on "1496247a19d4bfd8346001ac491b1d160b5ecbee"
Unverified Commit 63da7f56 authored by marload's avatar marload Committed by GitHub
Browse files

refactoring: Deduplication (#185)

parent 5fb22a05
...@@ -324,21 +324,21 @@ class DeepSpeedConfig(object): ...@@ -324,21 +324,21 @@ class DeepSpeedConfig(object):
elif train_batch is not None and \ elif train_batch is not None and \
micro_batch is not None: micro_batch is not None:
grad_acc = train_batch // micro_batch grad_acc = train_batch // micro_batch
grad_acc = grad_acc // self.world_size grad_acc //= self.world_size
self.gradient_accumulation_steps = grad_acc self.gradient_accumulation_steps = grad_acc
#micro_batch_per_gpu needs to be set #micro_batch_per_gpu needs to be set
elif train_batch is not None and \ elif train_batch is not None and \
grad_acc is not None: grad_acc is not None:
micro_batch = train_batch // self.world_size micro_batch = train_batch // self.world_size
micro_batch = micro_batch // grad_acc micro_batch //= grad_acc
self.train_micro_batch_size_per_gpu = micro_batch self.train_micro_batch_size_per_gpu = micro_batch
#train_batch_size needs to be set #train_batch_size needs to be set
elif micro_batch is not None and \ elif micro_batch is not None and \
grad_acc is not None: grad_acc is not None:
train_batch_size = micro_batch * grad_acc train_batch_size = micro_batch * grad_acc
train_batch_size = train_batch_size * self.world_size train_batch_size *= self.world_size
self.train_batch_size = train_batch_size self.train_batch_size = train_batch_size
#gradient_accumulation_steps and micro_batch_per_gpus is set #gradient_accumulation_steps and micro_batch_per_gpus is set
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment