Unverified Commit b5b1c3da authored by user4543's avatar user4543 Committed by GitHub
Browse files

Bug - Fix bug of duration feature for model benchmarks in distributed mode. (#347)

**Description**
Fix bug of duration feature for model benchmarks in distributed mode.

**Major Revision**
- Add all_reduce to sync the result of is_finished (the function that judges whether the model benchmark should be stopped) at each check step 
  - to avoid inconsistency between ranks when determining the end of the duration (otherwise some rank may enter one more step and never finish)
- Add torch.cuda.synchronize() before and after step time measuring in train_step() for all model benchmarks
  - some operations in train_step() may be asynchronous, resulting in incorrect step time records (for example, LSTM) 
parent 6456bcad
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
import os import os
from datetime import timedelta from datetime import timedelta
import time
import torch import torch
import transformers import transformers
...@@ -189,6 +190,33 @@ def _create_optimizer(self): ...@@ -189,6 +190,33 @@ def _create_optimizer(self):
return True return True
def _is_finished(self, curr_step, curr_time, check_frequency=100):
    """Judge whether the benchmarking should be stopped early or not.

    In duration mode with DDP, the local stop decision is all-reduced (MAX)
    across ranks so that every rank stops at the same step. Without this sync,
    one rank could decide to run an extra step and then block forever in a
    collective that the other ranks never enter.

    Args:
        curr_step (int): the current benchmarking step.
        curr_time (float): the current time in seconds got from time.time().
        check_frequency (int): the frequency (step numbers) to check if benchmark should be stopped.

    Return:
        True if the benchmarking should be stopped.
    """
    # Use an int flag (0/1) so it can be placed in a tensor and all-reduced.
    is_finished = int(super()._is_finished(curr_step, curr_time))
    if self._args.duration > 0:
        if curr_step % check_frequency == 0:
            # sync is_finished in distributed mode
            # if any rank is_finished is True, all ranks should be finished
            if self._args.distributed_impl == DistributedImpl.DDP:
                tensor = torch.IntTensor([is_finished])
                if self._args.distributed_backend == DistributedBackend.NCCL:
                    # NCCL can only reduce CUDA tensors, so move the flag to GPU.
                    tensor = tensor.cuda()
                torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.MAX)
                is_finished = tensor.tolist()[0]
        else:
            # Between check points, never report finished: this keeps all ranks
            # in lock-step until the next synchronized check step.
            is_finished = 0
    return (is_finished == 1)
def _sync_result(self, result): def _sync_result(self, result):
"""Function to reduce the result to rank 0. """Function to reduce the result to rank 0.
...@@ -259,3 +287,16 @@ def _cal_params_count(self): ...@@ -259,3 +287,16 @@ def _cal_params_count(self):
The count of trainable parameters. The count of trainable parameters.
""" """
return sum(p.numel() for p in self._model.parameters() if p.requires_grad) return sum(p.numel() for p in self._model.parameters() if p.requires_grad)
def _timer(self):
"""Returns the current time which ensures all previous CUDA events have been finished.
If there is no GPU present, this defaults to `time.time()`; otherwise it will
synchronize CUDA before measuring the time.
Returns:
Current time in second.
"""
if self._gpu_available:
torch.cuda.synchronize()
return time.time()
...@@ -3,8 +3,6 @@ ...@@ -3,8 +3,6 @@
"""Module of the Pytorch BERT model.""" """Module of the Pytorch BERT model."""
import time
import torch import torch
from transformers import BertModel, BertConfig from transformers import BertModel, BertConfig
...@@ -137,9 +135,10 @@ def _train_step(self, precision): ...@@ -137,9 +135,10 @@ def _train_step(self, precision):
""" """
duration = [] duration = []
curr_step = 0 curr_step = 0
check_frequency = 100
while True: while True:
for idx, sample in enumerate(self._dataloader): for idx, sample in enumerate(self._dataloader):
start = time.time() start = self._timer()
if self._gpu_available: if self._gpu_available:
sample = sample.cuda() sample = sample.cuda()
self._optimizer.zero_grad() self._optimizer.zero_grad()
...@@ -147,12 +146,12 @@ def _train_step(self, precision): ...@@ -147,12 +146,12 @@ def _train_step(self, precision):
loss = self._loss_fn(output, self._target) loss = self._loss_fn(output, self._target)
loss.backward() loss.backward()
self._optimizer.step() self._optimizer.step()
end = time.time() end = self._timer()
curr_step += 1 curr_step += 1
if curr_step > self._args.num_warmup: if curr_step > self._args.num_warmup:
# Save the step time of every training/inference step, unit is millisecond. # Save the step time of every training/inference step, unit is millisecond.
duration.append((end - start) * 1000) duration.append((end - start) * 1000)
if self._is_finished(curr_step, end): if self._is_finished(curr_step, end, check_frequency):
return duration return duration
def _inference_step(self, precision): def _inference_step(self, precision):
...@@ -171,13 +170,11 @@ def _inference_step(self, precision): ...@@ -171,13 +170,11 @@ def _inference_step(self, precision):
self._model.eval() self._model.eval()
while True: while True:
for idx, sample in enumerate(self._dataloader): for idx, sample in enumerate(self._dataloader):
start = time.time() start = self._timer()
if self._gpu_available: if self._gpu_available:
sample = sample.cuda() sample = sample.cuda()
self._model(sample) self._model(sample)
if self._gpu_available: end = self._timer()
torch.cuda.synchronize()
end = time.time()
curr_step += 1 curr_step += 1
if curr_step > self._args.num_warmup: if curr_step > self._args.num_warmup:
# Save the step time of every training/inference step, unit is millisecond. # Save the step time of every training/inference step, unit is millisecond.
......
...@@ -3,8 +3,6 @@ ...@@ -3,8 +3,6 @@
"""Module of the Pytorch CNN models.""" """Module of the Pytorch CNN models."""
import time
import torch import torch
from torchvision import models from torchvision import models
...@@ -99,10 +97,11 @@ def _train_step(self, precision): ...@@ -99,10 +97,11 @@ def _train_step(self, precision):
""" """
duration = [] duration = []
curr_step = 0 curr_step = 0
check_frequency = 100
while True: while True:
for idx, sample in enumerate(self._dataloader): for idx, sample in enumerate(self._dataloader):
sample = sample.to(dtype=getattr(torch, precision.value)) sample = sample.to(dtype=getattr(torch, precision.value))
start = time.time() start = self._timer()
if self._gpu_available: if self._gpu_available:
sample = sample.cuda() sample = sample.cuda()
self._optimizer.zero_grad() self._optimizer.zero_grad()
...@@ -110,12 +109,12 @@ def _train_step(self, precision): ...@@ -110,12 +109,12 @@ def _train_step(self, precision):
loss = self._loss_fn(output, self._target) loss = self._loss_fn(output, self._target)
loss.backward() loss.backward()
self._optimizer.step() self._optimizer.step()
end = time.time() end = self._timer()
curr_step += 1 curr_step += 1
if curr_step > self._args.num_warmup: if curr_step > self._args.num_warmup:
# Save the step time of every training/inference step, unit is millisecond. # Save the step time of every training/inference step, unit is millisecond.
duration.append((end - start) * 1000) duration.append((end - start) * 1000)
if self._is_finished(curr_step, end): if self._is_finished(curr_step, end, check_frequency):
return duration return duration
def _inference_step(self, precision): def _inference_step(self, precision):
...@@ -135,13 +134,11 @@ def _inference_step(self, precision): ...@@ -135,13 +134,11 @@ def _inference_step(self, precision):
while True: while True:
for idx, sample in enumerate(self._dataloader): for idx, sample in enumerate(self._dataloader):
sample = sample.to(dtype=getattr(torch, precision.value)) sample = sample.to(dtype=getattr(torch, precision.value))
start = time.time() start = self._timer()
if self._gpu_available: if self._gpu_available:
sample = sample.cuda() sample = sample.cuda()
self._model(sample) self._model(sample)
if self._gpu_available: end = self._timer()
torch.cuda.synchronize()
end = time.time()
curr_step += 1 curr_step += 1
if curr_step > self._args.num_warmup: if curr_step > self._args.num_warmup:
# Save the step time of every training/inference step, unit is millisecond. # Save the step time of every training/inference step, unit is millisecond.
......
...@@ -3,8 +3,6 @@ ...@@ -3,8 +3,6 @@
"""Module of the Pytorch GPT2 model.""" """Module of the Pytorch GPT2 model."""
import time
import torch import torch
from transformers import GPT2Model, GPT2Config from transformers import GPT2Model, GPT2Config
...@@ -131,9 +129,10 @@ def _train_step(self, precision): ...@@ -131,9 +129,10 @@ def _train_step(self, precision):
""" """
duration = [] duration = []
curr_step = 0 curr_step = 0
check_frequency = 100
while True: while True:
for idx, sample in enumerate(self._dataloader): for idx, sample in enumerate(self._dataloader):
start = time.time() start = self._timer()
if self._gpu_available: if self._gpu_available:
sample = sample.cuda() sample = sample.cuda()
self._optimizer.zero_grad() self._optimizer.zero_grad()
...@@ -141,12 +140,12 @@ def _train_step(self, precision): ...@@ -141,12 +140,12 @@ def _train_step(self, precision):
loss = self._loss_fn(output[range(self._args.batch_size), -1], self._target) loss = self._loss_fn(output[range(self._args.batch_size), -1], self._target)
loss.backward() loss.backward()
self._optimizer.step() self._optimizer.step()
end = time.time() end = self._timer()
curr_step += 1 curr_step += 1
if curr_step > self._args.num_warmup: if curr_step > self._args.num_warmup:
# Save the step time of every training/inference step, unit is millisecond. # Save the step time of every training/inference step, unit is millisecond.
duration.append((end - start) * 1000) duration.append((end - start) * 1000)
if self._is_finished(curr_step, end): if self._is_finished(curr_step, end, check_frequency):
return duration return duration
def _inference_step(self, precision): def _inference_step(self, precision):
...@@ -165,13 +164,11 @@ def _inference_step(self, precision): ...@@ -165,13 +164,11 @@ def _inference_step(self, precision):
self._model.eval() self._model.eval()
while True: while True:
for idx, sample in enumerate(self._dataloader): for idx, sample in enumerate(self._dataloader):
start = time.time() start = self._timer()
if self._gpu_available: if self._gpu_available:
sample = sample.cuda() sample = sample.cuda()
self._model(sample) self._model(sample)
if self._gpu_available: end = self._timer()
torch.cuda.synchronize()
end = time.time()
curr_step += 1 curr_step += 1
if curr_step > self._args.num_warmup: if curr_step > self._args.num_warmup:
# Save the step time of every training/inference step, unit is millisecond. # Save the step time of every training/inference step, unit is millisecond.
......
...@@ -3,8 +3,6 @@ ...@@ -3,8 +3,6 @@
"""Module of the Pytorch LSTM model.""" """Module of the Pytorch LSTM model."""
import time
import torch import torch
from superbench.common.utils import logger from superbench.common.utils import logger
...@@ -139,10 +137,11 @@ def _train_step(self, precision): ...@@ -139,10 +137,11 @@ def _train_step(self, precision):
""" """
duration = [] duration = []
curr_step = 0 curr_step = 0
check_frequency = 100
while True: while True:
for idx, sample in enumerate(self._dataloader): for idx, sample in enumerate(self._dataloader):
sample = sample.to(dtype=getattr(torch, precision.value)) sample = sample.to(dtype=getattr(torch, precision.value))
start = time.time() start = self._timer()
if self._gpu_available: if self._gpu_available:
sample = sample.cuda() sample = sample.cuda()
self._optimizer.zero_grad() self._optimizer.zero_grad()
...@@ -150,12 +149,12 @@ def _train_step(self, precision): ...@@ -150,12 +149,12 @@ def _train_step(self, precision):
loss = self._loss_fn(output, self._target) loss = self._loss_fn(output, self._target)
loss.backward() loss.backward()
self._optimizer.step() self._optimizer.step()
end = time.time() end = self._timer()
curr_step += 1 curr_step += 1
if curr_step > self._args.num_warmup: if curr_step > self._args.num_warmup:
# Save the step time of every training/inference step, unit is millisecond. # Save the step time of every training/inference step, unit is millisecond.
duration.append((end - start) * 1000) duration.append((end - start) * 1000)
if self._is_finished(curr_step, end): if self._is_finished(curr_step, end, check_frequency):
return duration return duration
def _inference_step(self, precision): def _inference_step(self, precision):
...@@ -175,13 +174,11 @@ def _inference_step(self, precision): ...@@ -175,13 +174,11 @@ def _inference_step(self, precision):
while True: while True:
for idx, sample in enumerate(self._dataloader): for idx, sample in enumerate(self._dataloader):
sample = sample.to(dtype=getattr(torch, precision.value)) sample = sample.to(dtype=getattr(torch, precision.value))
start = time.time() start = self._timer()
if self._gpu_available: if self._gpu_available:
sample = sample.cuda() sample = sample.cuda()
self._model(sample) self._model(sample)
if self._gpu_available: end = self._timer()
torch.cuda.synchronize()
end = time.time()
curr_step += 1 curr_step += 1
if curr_step > self._args.num_warmup: if curr_step > self._args.num_warmup:
# Save the step time of every training/inference step, unit is millisecond. # Save the step time of every training/inference step, unit is millisecond.
......
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
"""Tests for BenchmarkRegistry module.""" """Tests for BenchmarkRegistry module."""
import time
import numbers import numbers
import torch import torch
...@@ -118,7 +117,7 @@ def _train_step(self, precision): ...@@ -118,7 +117,7 @@ def _train_step(self, precision):
duration = [] duration = []
for idx, sample in enumerate(self._dataloader): for idx, sample in enumerate(self._dataloader):
sample = sample.to(dtype=getattr(torch, precision.value)) sample = sample.to(dtype=getattr(torch, precision.value))
start = time.time() start = self._timer()
if self._gpu_available: if self._gpu_available:
sample = sample.cuda() sample = sample.cuda()
self._optimizer.zero_grad() self._optimizer.zero_grad()
...@@ -126,7 +125,7 @@ def _train_step(self, precision): ...@@ -126,7 +125,7 @@ def _train_step(self, precision):
loss = self._loss_fn(output, self._target) loss = self._loss_fn(output, self._target)
loss.backward() loss.backward()
self._optimizer.step() self._optimizer.step()
end = time.time() end = self._timer()
if idx % 10 == 0: if idx % 10 == 0:
logger.info( logger.info(
'Train step [{}/{} ({:.0f}%)]'.format( 'Train step [{}/{} ({:.0f}%)]'.format(
...@@ -153,13 +152,13 @@ def _inference_step(self, precision): ...@@ -153,13 +152,13 @@ def _inference_step(self, precision):
self._model.eval() self._model.eval()
for idx, sample in enumerate(self._dataloader): for idx, sample in enumerate(self._dataloader):
sample = sample.to(dtype=getattr(torch, precision.value)) sample = sample.to(dtype=getattr(torch, precision.value))
start = time.time() start = self._timer()
if self._gpu_available: if self._gpu_available:
sample = sample.cuda() sample = sample.cuda()
self._model(sample) self._model(sample)
if self._gpu_available: if self._gpu_available:
torch.cuda.synchronize() torch.cuda.synchronize()
end = time.time() end = self._timer()
if idx % 10 == 0: if idx % 10 == 0:
logger.info( logger.info(
'Inference step [{}/{} ({:.0f}%)]'.format( 'Inference step [{}/{} ({:.0f}%)]'.format(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment