Unverified commit e0bc5d62, authored by mcarilli, committed by GitHub

Merging in fused Adam optimizer and additional DDP features tested in 18.10 (#60)

* test passes

* notes

* Using C++-side flatten and unflatten functions

* Adding csrc

* Persistent synchronization event so it doesn't need to be created and destroyed each time

* Interop with parameter flattening in SSD

* Added deterministic option to imagenet main.py

* Adding options to split gradient averaging, and to allreduce in pure fp32

* Fixing allreduce_maybe_retain call

* Fixing allreduce_fallback

* Also sync active_i_buckets from rank 0

* Making retain_allreduce_buffers compatible with/orthogonal to delay_allreduce=True|False

* Correcting syntax error, now all seems to work with SSD

* Optional cpp extension build

* Add mixed precision adam optimizer (#59)

* Add FusedAdam optimizer to Apex that places all the math into a single CUDA kernel.

* Added fixes to fused_adam to get it to work with network.

* WIP: Python interface for Adam with options

* fix dispatch for half tensors, add Python options to handle optional half gradients and params

* cleanup, get rid of grid-stride loop
parent 81eef1ef
@@ -3,3 +3,4 @@
from . import fp16_utils
from . import parallel
from . import amp
from . import optimizers
from .fused_adam import FusedAdam
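With these exports in place, the new optimizer is importable as `apex.optimizers.FusedAdam`. A minimal usage sketch (assuming Apex was installed with the new `--cuda_ext` option so the `fused_adam_cuda` extension exists):

```python
import torch
from apex.optimizers import FusedAdam

model = torch.nn.Linear(10, 10).cuda()
optimizer = FusedAdam(model.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-8)

loss = model(torch.randn(4, 10, device="cuda")).sum()
loss.backward()
optimizer.step()  # ordinary single-precision use; gradients are read from p.grad
```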
#include <torch/torch.h>

// CUDA forward declaration
void fused_adam_cuda(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode);

#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

// C++ interface
void adam(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode) {
    CHECK_INPUT(p);
    if (p_copy.numel() > 0) CHECK_INPUT(p_copy);
    CHECK_INPUT(m);
    CHECK_INPUT(v);
    CHECK_INPUT(g);

    int64_t num_elem = p.numel();
    AT_ASSERTM(m.numel() == num_elem, "number of elements in m and p tensors should be equal");
    AT_ASSERTM(v.numel() == num_elem, "number of elements in v and p tensors should be equal");
    AT_ASSERTM(g.numel() == num_elem, "number of elements in g and p tensors should be equal");
    AT_ASSERTM(p_copy.numel() == num_elem || p_copy.numel() == 0, "number of elements in p_copy and p tensors should be equal, or p_copy should be empty");

    fused_adam_cuda(p, p_copy, m, v, g, lr, beta1, beta2, eps, grad_scale, step, mode);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("adam", &adam, "Adam optimized CUDA implementation.");
}
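The checks above amount to a simple contract on the raw binding: p, m, v, and g must be contiguous CUDA tensors of equal numel, and p_copy is either empty or the same size. A hedged sketch of calling the extension directly (normally FusedAdam below does this for you; the tensor names here are illustrative):

```python
import torch
import fused_adam_cuda  # built by setup.py when --cuda_ext is passed

n = 1024
p = torch.zeros(n, device="cuda")         # fp32 parameters
m = torch.zeros(n, device="cuda")         # exp_avg
v = torch.zeros(n, device="cuda")         # exp_avg_sq
g = torch.randn(n, device="cuda")         # gradients (fp32 here; fp16 is also handled)
p_copy = torch.tensor([], device="cuda")  # empty => no fp16 copy of the updated params

# Arguments: p, p_copy, m, v, g, lr, beta1, beta2, eps, grad_scale, step,
# mode (0: eps inside the sqrt, 1: eps outside the sqrt -- FusedAdam's default).
fused_adam_cuda.adam(p, p_copy, m, v, g, 1e-3, 0.9, 0.999, 1e-8, 1.0, 1, 1)
```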
#include "ATen/ATen.h"
#include "ATen/cuda/CUDAContext.h"
#include "ATen/cuda/detail/IndexUtils.cuh"
#include <cuda.h>
#include <cuda_runtime.h>
#include <stdio.h>
#include <cmath>
#include "ATen/TensorUtils.h"
#include "ATen/Type.h"
#include "ATen/AccumulateType.h"
#include <THC/THCGeneral.h>
typedef enum{
ADAM_MODE_0 =0, // eps under square root
ADAM_MODE_1 =1 // eps outside square root
} adamMode_t;
template <typename T, typename GRAD_T>
__global__ void adam_cuda_kernel(
T* __restrict__ p,
GRAD_T* __restrict__ p_copy, // For mixed precision training, pass NULL if not needed
T* __restrict__ m,
T* __restrict__ v,
const GRAD_T * __restrict__ g,
const float b1,
const float b2,
const float eps,
const float grad_scale,
const float step_size,
const size_t tsize,
adamMode_t mode) {
//Assuming 2D grids and 2D blocks
const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
const int threadsPerBlock = blockDim.x * blockDim.y;
const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x;
const int i = (blockId * threadsPerBlock + threadIdInBlock);
const int totThreads = gridDim.x*gridDim.y*threadsPerBlock;
for (int j = i; j < tsize; j+=totThreads) {
T scaled_grad = g[j]/grad_scale;
m[j] = b1*m[j] + (1-b1)*scaled_grad;
v[j] = b2*v[j] + (1-b2)*scaled_grad*scaled_grad;
float denom;
if (mode == ADAM_MODE_0)
denom = sqrtf(v[j] + eps);
else // Mode 1
denom = sqrtf(v[j]) + eps;
p[j] = p[j] - (step_size*m[j]/denom);
if (p_copy != NULL) p_copy[j] = (GRAD_T) p[j];
}
}
void fused_adam_cuda(
at::Tensor & p,
at::Tensor & p_copy,
at::Tensor & m,
at::Tensor & v,
at::Tensor & g,
float lr,
float beta1,
float beta2,
float eps,
float grad_scale,
int step,
int mode) {
//Get tensor size
int tsize = p.numel();
//Determine #threads and #blocks
const int threadsPerBlock = 512;
const dim3 blocks((tsize+threadsPerBlock-1)/threadsPerBlock);
AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p), "parameter tensor is too large to be indexed with int32");
//Constants
const float bias_correction1 = 1 - std::pow(beta1, step);
const float bias_correction2 = 1 - std::pow(beta2, step);
const float step_size = lr * std::sqrt(bias_correction2)/bias_correction1;
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (g.type().scalarType() == at::ScalarType::Half) {
//all other values should be fp32 for half gradients
AT_ASSERTM(p.type().scalarType() == at::ScalarType::Float, "expected parameter to be of float type");
//dispatch is done on the gradient type
AT_DISPATCH_FLOATING_TYPES_AND_HALF(g.type(), "adam_cuda_kernel", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
adam_cuda_kernel<accscalar_t, scalar_t><<<blocks,threadsPerBlock, 0, stream>>>(
p.data<accscalar_t>(),
p_copy.numel() ? p_copy.data<scalar_t>() : NULL,
m.data<accscalar_t>(),
v.data<accscalar_t>(),
g.data<scalar_t>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
tsize,
(adamMode_t) mode);
}));
} else {
AT_DISPATCH_FLOATING_TYPES(g.type(), "adam_cuda_kernel", ([&] {
adam_cuda_kernel<scalar_t, scalar_t><<<blocks,threadsPerBlock, 0, stream>>>(
p.data<scalar_t>(),
NULL, //don't output p_copy for fp32, it's wasted write
m.data<scalar_t>(),
v.data<scalar_t>(),
g.data<scalar_t>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
tsize,
(adamMode_t) mode);
}));
}
THCudaCheck(cudaGetLastError());
}
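For readers checking the kernel math, here is a plain PyTorch sketch of the per-tensor update the kernel performs, with the bias corrections folded into step_size exactly as the host function does. This is a verification aid, not part of the commit:

```python
import torch

def adam_reference(p, m, v, g, lr, b1, b2, eps, grad_scale, step, eps_inside_sqrt):
    # Mirrors adam_cuda_kernel: one fused update, bias correction folded into step_size.
    bias_correction1 = 1 - b1 ** step
    bias_correction2 = 1 - b2 ** step
    step_size = lr * (bias_correction2 ** 0.5) / bias_correction1

    scaled_grad = g.float() / grad_scale
    m.mul_(b1).add_((1 - b1) * scaled_grad)
    v.mul_(b2).add_((1 - b2) * scaled_grad * scaled_grad)
    # ADAM_MODE_0 puts eps under the square root, ADAM_MODE_1 outside it.
    denom = (v + eps).sqrt() if eps_inside_sqrt else v.sqrt() + eps
    p.sub_(step_size * m / denom)
    return p
```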
import torch
import fused_adam_cuda


class FusedAdam(torch.optim.Adam):
    """Implements Adam algorithm.

    It has been proposed in `Adam: A Method for Stochastic Optimization`_.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
            (default: False) NOT SUPPORTED in FusedAdam!

    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0, amsgrad=False, eps_inside_sqrt=False):
        if amsgrad:
            raise RuntimeError('FusedAdam does not support the AMSGrad variant.')
        super(FusedAdam, self).__init__(params, lr, betas, eps, weight_decay, amsgrad)
        self.eps_mode = 0 if eps_inside_sqrt else 1

    def step(self, closure=None, grads=None, output_params=None, scale=1.):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        if grads is not None:
            assert len(self.param_groups) == 1, "mixed precision optimizer works for a single group only"

        for group in self.param_groups:
            if grads is None:
                grads = [None] * len(group['params'])
            if output_params is None:
                output_params = [None] * len(group['params'])
            for p, grad, output_param in zip(group['params'], grads, output_params):
                # Note: p.grad should not ever be set for correct operation of the mixed
                # precision optimizer, which sometimes sends None gradients.
                if p.grad is None and grad is None:
                    continue
                if grad is None:
                    grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('FusedAdam does not support sparse gradients, please consider SparseAdam instead')

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

                out_p = torch.tensor([], dtype=torch.float) if output_param is None else output_param
                fused_adam_cuda.adam(p.data,
                                     out_p,
                                     exp_avg,
                                     exp_avg_sq,
                                     grad,
                                     group['lr'],
                                     beta1,
                                     beta2,
                                     group['eps'],
                                     scale,
                                     state['step'],
                                     self.eps_mode)
        return loss
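The extra step() arguments exist for mixed precision: the optimizer holds fp32 master weights, the caller passes the fp16 gradients and the fp16 weights to refresh (output_params), and scale undoes the loss scaling inside the kernel. A hypothetical sketch of that pattern (the fp32/fp16 bookkeeping would normally be handled by the training script or apex.fp16_utils; the names here are illustrative):

```python
import torch
from apex.optimizers import FusedAdam

# Hypothetical mixed-precision step: fp32 master weights live in the optimizer,
# fp16 model weights are refreshed in place via output_params, and the raw fp16
# gradients are un-scaled inside the kernel via scale.
model = torch.nn.Linear(1024, 1024).cuda().half()
master_params = [p.detach().clone().float() for p in model.parameters()]
optimizer = FusedAdam(master_params, lr=1e-3)

loss_scale = 128.0
out = model(torch.randn(8, 1024, device="cuda", dtype=torch.half))
(out.float().sum() * loss_scale).backward()

fp16_grads = [p.grad.data for p in model.parameters()]
fp16_weights = [p.data for p in model.parameters()]
optimizer.step(grads=fp16_grads, output_params=fp16_weights, scale=loss_scale)
```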
import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
# from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

try:
    from apex_C import flatten
    from apex_C import unflatten
except ImportError:
    print("Apex was built without --cpp_ext; falling back to Python flatten and unflatten")
    from torch._utils import _flatten_dense_tensors as flatten
    from torch._utils import _unflatten_dense_tensors as unflatten

import torch.distributed as dist
from torch.nn.modules import Module
from torch.autograd import Variable
@@ -9,7 +16,8 @@ import copy

# apply_dist_call requires that tensors in 'bucket' are all the same type.
def apply_flat_dist_call(bucket, call, extra_args=None):
    coalesced = _flatten_dense_tensors(bucket)
    coalesced = flatten(bucket)

    if extra_args is not None:
        call(coalesced, *extra_args)
@@ -19,19 +27,30 @@ def apply_flat_dist_call(bucket, call, extra_args=None):
    if call is dist.all_reduce:
        coalesced /= dist.get_world_size()

    for buf, synced in zip(bucket, _unflatten_dense_tensors(coalesced, bucket)):
    for buf, synced in zip(bucket, unflatten(coalesced, bucket)):
        buf.copy_(synced)
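apply_flat_dist_call boils down to: flatten the bucket, run the collective on one contiguous buffer, then scatter the result back. A standalone sketch of that round trip, with dist.all_reduce replaced by a local stand-in so it runs in a single process:

```python
import torch
from torch._utils import _flatten_dense_tensors as flatten
from torch._utils import _unflatten_dense_tensors as unflatten

# Each bucket holds same-typed tensors; the collective operates on one flat buffer
# and the results are copied back, which is what apply_flat_dist_call does.
bucket = [torch.full((3,), 1.0), torch.full((5,), 2.0)]
coalesced = flatten(bucket)            # one contiguous 8-element buffer
coalesced /= 2.0                       # stand-in for dist.all_reduce + averaging
for buf, synced in zip(bucket, unflatten(coalesced, bucket)):
    buf.copy_(synced)
print([t.tolist() for t in bucket])    # [[0.5, 0.5, 0.5], [1.0, 1.0, 1.0]]
```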
def split_half_float_double(tensors):
    dtypes = ["torch.cuda.HalfTensor", "torch.cuda.FloatTensor", "torch.cuda.DoubleTensor"]
    buckets = []
    for i, dtype in enumerate(dtypes):
        bucket = [t for t in tensors if t.type() == dtype]
        if bucket:
            buckets.append(bucket)
    return buckets

# flat_dist_call organizes 'tensors' by type.
def flat_dist_call(tensors, call, extra_args=None):
    flat_dist_call.warn_on_half = True

def split_by_type(tensors):
    buckets = OrderedDict()
    for tensor in tensors:
        tp = tensor.type()
        if tp not in buckets:
            buckets[tp] = []
        buckets[tp].append(tensor)
    return buckets

# flat_dist_call organizes 'tensors' by type.
def flat_dist_call(tensors, call, extra_args=None):
    buckets = split_by_type(tensors)

    for tp in buckets:
        bucket = buckets[tp]
@@ -121,7 +140,15 @@ class DistributedDataParallel(Module):
    """

    def __init__(self, module, message_size=10000000, delay_allreduce=False, shared_param=None):
    def __init__(self,
                 module,
                 message_size=10000000,
                 delay_allreduce=False,
                 shared_param=None,
                 allreduce_trigger_params=None,
                 retain_allreduce_buffers=False,
                 allreduce_always_fp32=False,
                 gradient_average_split_factor=1.0):
        super(DistributedDataParallel, self).__init__()

        # Backward/forward compatibility around
@@ -138,10 +165,24 @@ class DistributedDataParallel(Module):
        if shared_param is not None:
            raise ValueError("shared_param is no longer supported as an option. It was misleadingly named from the start. It turns out overlapping communication with computation should work fine with shared parameters. If you still wish to delay communication to the end of the backward pass, use delay_allreduce=True instead.")

        self.world_size = float(dist.get_world_size())

        self.retain_allreduce_buffers = retain_allreduce_buffers
        self.allreduce_always_fp32 = allreduce_always_fp32
        self.gradient_average_split_factor = gradient_average_split_factor

        self.custom_allreduce_triggers = False
        if allreduce_trigger_params is not None:
            if delay_allreduce:
                raise ValueError("Setting allreduce_trigger_params is only valid if delay_allreduce=False.")
            self.custom_allreduce_triggers = True
            self.allreduce_trigger_params = set([id(param) for param in allreduce_trigger_params])

        self.delay_allreduce = delay_allreduce
        self.message_size = message_size

        self.reduction_stream = torch.cuda.Stream()
        self.reduction_event = torch.cuda.Event(enable_timing=False, blocking=False)

        self.module = module
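For reference, a hedged construction sketch showing how the new arguments line up (the process group is assumed to be initialized already, e.g. via torch.distributed.launch plus dist.init_process_group; the wrapped module is a placeholder):

```python
import torch
import torch.distributed as dist
from apex.parallel import DistributedDataParallel as DDP

# Assumes torch.distributed is initialized, e.g.
# dist.init_process_group(backend='nccl', init_method='env://')
model = torch.nn.Linear(1024, 1024).cuda()
model = DDP(model,
            message_size=10000000,              # bucket-size threshold, counted in elements
            delay_allreduce=False,              # overlap allreduces with backward
            allreduce_trigger_params=None,      # or a list of params whose gradients close out a bucket
            retain_allreduce_buffers=False,     # keep flattened allreduced grads in model.allreduce_buffers
            allreduce_always_fp32=False,        # upcast each bucket to fp32 before the allreduce
            gradient_average_split_factor=1.0)  # split the 1/world_size scaling around the allreduce
```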
@@ -163,12 +204,14 @@
    def __setstate__(self, state):
        super(DistributedDataParallel, self).__setstate__(state)
        self.reduction_stream = torch.cuda.Stream()
        self.reduction_event = torch.cuda.Event(enable_timing=False, blocking=False)

    def __getstate__(self):
        attrs = copy.copy(self.__dict__)
        if self._backend != self.backend_enum_holder.NCCL:
            del attrs['reduction_stream']
            del attrs['reduction_event']
        return attrs
    # Broadcast rank 0's bucket structure across all processes, and have all processes
@@ -177,14 +220,14 @@
        # Append leftover buckets
        for tmp_bucket in self.tmp_buckets:
            if len(tmp_bucket) > 0:
                self.buckets.append(tmp_bucket)
                self.active_i_buckets.append(tmp_bucket)

        self.num_buckets = len(self.buckets)
        self.bucket_sizes = [len(bucket) for bucket in self.buckets]
        self.num_buckets = len(self.active_i_buckets)
        self.bucket_sizes = [len(bucket) for bucket in self.active_i_buckets]

        info_tensor = torch.cuda.IntTensor([self.num_buckets] +
                                           self.bucket_sizes +
                                           list(chain(*self.buckets)))
                                           list(chain(*self.active_i_buckets)))

        dist.broadcast(info_tensor, 0)
@@ -192,13 +235,19 @@
        self.num_buckets = info[0]
        self.bucket_sizes = info[1:self.num_buckets + 1]

        self.buckets = [[None for _ in range(self.bucket_sizes[i])] for i in range(self.num_buckets)]
        self.buckets = [[None for _ in range(self.bucket_sizes[i])]
                        for i in range(self.num_buckets)]

        # Technically, active_i_buckets' work is done. But the information is still useful to
        # keep around. Therefore, refresh active_i_buckets based on rank 0 as well.
        self.active_i_buckets = [[None for _ in range(self.bucket_sizes[i])]
                                 for i in range(self.num_buckets)]

        flattened_buckets = info[self.num_buckets + 1:]
        flat_i = 0
        for bucket_idx in range(self.num_buckets):
            for bucket_loc in range(self.bucket_sizes[bucket_idx]):
                param_i = flattened_buckets[flat_i]
                self.active_i_buckets[bucket_idx][bucket_loc] = param_i
                self.param_id_to_bucket[id(self.active_params[param_i])] = (bucket_idx, bucket_loc)
                flat_i += 1
@@ -216,12 +265,12 @@
                    self.needs_refresh = False

            grads = [param.grad.data for param in self.module.parameters() if param.grad is not None]
            self.allreduce_fallback()
            flat_dist_call(grads, dist.all_reduce)

        def overlapping_backward_epilogue():
            torch.cuda.current_stream().wait_stream(self.reduction_stream)
            self.reduction_stream.record_event(self.reduction_event)
            torch.cuda.current_stream().wait_event(self.reduction_event)

            # Sanity checks that all the buckets were kicked off
            if self.next_bucket != self.num_buckets:
@@ -253,10 +302,20 @@
current_type = self.param_type_to_tmp_i[param.type()]
self.tmp_buckets[current_type].append(active_i)
self.tmp_numels[current_type] += param.numel()
ship_tmp_bucket = False
if self.custom_allreduce_triggers:
if id(param) in self.allreduce_trigger_params:
ship_tmp_bucket = True
else:
self.tmp_numels[current_type] += param.numel()
if self.tmp_numels[current_type] >= self.message_size:
self.buckets.append(self.tmp_buckets[current_type])
ship_tmp_bucket = True
# To consider: If custom_allreduce_triggers are in use, ship all
# tmp_buckets, not just tmp_buckets[current_type].
if ship_tmp_bucket:
self.active_i_buckets.append(self.tmp_buckets[current_type])
self.tmp_buckets[current_type] = []
self.tmp_numels[current_type] = 0
@@ -275,6 +334,53 @@
                wrapper(param)

    def allreduce_bucket(self, bucket):
        tensor = flatten(bucket)

        tensor_to_allreduce = tensor

        if self.allreduce_always_fp32:
            tensor_to_allreduce = tensor.float()

        if self.gradient_average_split_factor != 1.0:
            tensor_to_allreduce.mul_(1./self.gradient_average_split_factor)

        dist.all_reduce(tensor_to_allreduce)

        tensor_to_allreduce.mul_(self.gradient_average_split_factor/self.world_size)

        if self.allreduce_always_fp32 and tensor is not tensor_to_allreduce:
            tensor.copy_(tensor_to_allreduce)

        return tensor

    def allreduce_maybe_retain(self, bucket, bucket_idx=-1):
        allreduced = self.allreduce_bucket(bucket)
        if self.retain_allreduce_buffers:
            if self.allreduce_buffers[bucket_idx] is not None:
                raise RuntimeError("The backward pass is attempting to replace an already-filled "
                                   "allreduce buffer. This is almost certainly an error.")
            self.allreduce_buffers[bucket_idx] = allreduced
        else:
            for buf, synced in zip(bucket, unflatten(allreduced, bucket)):
                buf.copy_(synced)

    def allreduce_fallback(self):
        grads = [param.grad.data for param in self.module.parameters() if param.grad is not None]

        split_buckets = split_half_float_double(grads)

        # If retain_allreduce_buffers is True and delay_allreduce is False,
        # this will only be done during the first backward pass, ignored by the
        # training script, and overwritten in the next forward pass. So it's harmless.
        if self.retain_allreduce_buffers:
            self.allreduce_buffers = [None for _ in range(len(split_buckets))]

        for i, bucket in enumerate(split_buckets):
            allreduced = self.allreduce_maybe_retain(bucket, i)
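The scaling in allreduce_bucket splits the usual division by world_size into a pre-divide by gradient_average_split_factor and a post-multiply by gradient_average_split_factor/world_size. Mathematically the factor cancels; the split presumably just keeps fp16 buckets in a safer numeric range around the allreduce. A small numeric check of the equivalence:

```python
import torch

world_size = 8
split_factor = 2.0
per_rank_grads = [torch.full((4,), 3.0) for _ in range(world_size)]

# What allreduce_bucket computes, with dist.all_reduce replaced by an explicit sum:
summed = sum(g / split_factor for g in per_rank_grads)   # pre-divide, then "allreduce"
result = summed * (split_factor / world_size)            # post-scale

expected = sum(per_rank_grads) / world_size              # plain gradient average
print(torch.allclose(result, expected))                  # True
```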
    def comm_ready_buckets(self, param):
        # Need to do this in every hook for compatibility with Ruberry's streaming backward PR.
@@ -291,9 +397,10 @@
if self.buckets_ready_size[bucket_idx] == self.bucket_sizes[bucket_idx]:
if bucket_idx == self.next_bucket:
self.reduction_stream.wait_stream(torch.cuda.current_stream())
torch.cuda.current_stream().record_event(self.reduction_event)
self.reduction_stream.wait_event(self.reduction_event)
with torch.cuda.stream(self.reduction_stream):
apply_flat_dist_call(self.buckets[bucket_idx], dist.all_reduce)
self.allreduce_maybe_retain(self.buckets[bucket_idx], bucket_idx)
self.next_bucket += 1
@@ -306,7 +413,7 @@
if i > self.next_bucket:
break
elif i == self.next_bucket:
apply_flat_dist_call(self.buckets[i], dist.all_reduce)
self.allreduce_maybe_retain(self.buckets[i], i)
self.ready_buckets_not_reduced.remove(i)
self.next_bucket += 1
else:
@@ -331,6 +438,7 @@
                self.needs_refresh = True

            if self.needs_refresh:
                self.active_i_buckets = []
                self.buckets = []
                self.tmp_buckets = [[], [], []] # [running half, float, double buckets]
                self.tmp_numels = [0, 0, 0]
@@ -341,6 +449,8 @@
                self.buckets = [[None for _ in range(self.bucket_sizes[i])]
                                for i in range(self.num_buckets)]
                self.buckets_ready_size = [0 for i in range(self.num_buckets)]
                if self.retain_allreduce_buffers:
                    self.allreduce_buffers = [None for _ in range(self.num_buckets)]
                self.next_bucket = 0
                self.ready_buckets_not_reduced = set()
#include <torch/extension.h>
#include <torch/csrc/utils/tensor_flatten.h>
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/utils/tensor_flatten.h

at::Tensor flatten(std::vector<at::Tensor> tensors)
{
    return torch::utils::flatten_dense_tensors(tensors);
}

std::vector<at::Tensor> unflatten(at::Tensor flat, std::vector<at::Tensor> tensors)
{
    return torch::utils::unflatten_dense_tensors(flat, tensors);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("flatten", &flatten, "Flatten dense tensors");
    m.def("unflatten", &unflatten, "Unflatten dense tensors");
}
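apex_C.flatten and apex_C.unflatten wrap the C++ routines whose behavior the Python helpers in torch._utils reproduce, which is why the ImportError fallback above is a drop-in replacement. A quick equivalence sketch (assumes Apex was built with --cpp_ext; both paths should produce the same flat buffer):

```python
import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
import apex_C  # available only if Apex was built with --cpp_ext

tensors = [torch.randn(2, 3), torch.randn(5)]

flat_py = _flatten_dense_tensors(tensors)
flat_cpp = apex_C.flatten(tensors)
print(torch.equal(flat_py, flat_cpp))   # expected: True, same concatenation order

restored = apex_C.unflatten(flat_cpp, tensors)
print([t.shape for t in restored])      # [torch.Size([2, 3]), torch.Size([5])]
```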
@@ -65,6 +65,7 @@ parser.add_argument('--static-loss-scale', type=float, default=1,
                    help='Static loss scale, positive power of 2 values can improve fp16 convergence.')
parser.add_argument('--prof', dest='prof', action='store_true',
                    help='Only run 10 iterations for profiling.')
parser.add_argument('--deterministic', action='store_true')
parser.add_argument("--local_rank", default=0, type=int)
parser.add_argument('--sync_bn', action='store_true',
@@ -92,6 +93,11 @@ def fast_collate(batch):
best_prec1 = 0
args = parser.parse_args()

if args.deterministic:
    cudnn.benchmark = False
    cudnn.deterministic = True
    torch.manual_seed(args.local_rank)

def main():
    global best_prec1, args
@@ -18,14 +18,32 @@ if TORCH_MAJOR == 0 and TORCH_MINOR < 4:
cmdclass = {}
ext_modules = []

if "--cpp_ext" in sys.argv or "--cuda_ext" in sys.argv:
    from torch.utils.cpp_extension import BuildExtension
    cmdclass['build_ext'] = BuildExtension

if "--cpp_ext" in sys.argv:
    from torch.utils.cpp_extension import CppExtension
    sys.argv.remove("--cpp_ext")
    ext_modules.append(
        CppExtension('apex_C',
                     ['csrc/flatten_unflatten.cpp',]))

if "--cuda_ext" in sys.argv:
    from torch.utils.cpp_extension import CUDAExtension, BuildExtension
    from torch.utils.cpp_extension import CUDAExtension
    sys.argv.remove("--cuda_ext")
    cmdclass['build_ext'] = BuildExtension
    ext_modules.append(CUDAExtension('syncbn',[
        'csrc/syncbn.cpp',
        'csrc/welford.cu'
    ]))
    ext_modules.append(
        CUDAExtension(name='fused_adam_cuda',
                      sources=['apex/optimizers/csrc/fused_adam_cuda.cpp',
                               'apex/optimizers/csrc/fused_adam_cuda_kernel.cu'],
                      extra_compile_args={'cxx': ['-O3',],
                                          'nvcc': ['--gpu-architecture=sm_70',
                                                   '-O3',
                                                   '--use_fast_math']}))
    ext_modules.append(
        CUDAExtension(name='syncbn',
                      sources=['csrc/syncbn.cpp',
                               'csrc/welford.cu']))
setup(
@@ -33,20 +33,23 @@ class Model(Module):
    def forward(self, input):
        return (input*self.a)*self.b

model = DDP(Model(), message_size=1)
# model = DDP(Model(), delay_allreduce=True)

model = Model()
# model = DDP(model, message_size=1, gradient_average_split_factor=2.0)
# model = DDP(model, delay_allreduce=True)
model = DDP(model, message_size=1, allreduce_trigger_params=[model.b])

x = torch.cuda.FloatTensor(4096*4096)

passed = True

torch.cuda.cudart().cudaProfilerStart()
for i in range(10):
    x.fill_(i + args.local_rank)  # fill x with new values every iteration for sanity
    model.zero_grad()
    out = model(x)
    loss = out.sum()
    # torch.cuda.nvtx.range_push("backward")
    torch.cuda.nvtx.range_push("backward")
    loss.backward()
    # torch.cuda.nvtx.range_pop()
    torch.cuda.nvtx.range_pop()
    # torch.cuda.nvtx.range_push("synchronize() + info")
    # torch.cuda.synchronize()
@@ -60,5 +63,6 @@ for i in range(10):
    if not info("model.a", model.module.a, 2.): passed = False
    if not info("model.b", model.module.b, 1.): passed = False
    # torch.cuda.nvtx.range_pop()

torch.cuda.cudart().cudaProfilerStop()

print("passed = ", passed)