Commit 651150cb authored by Michael Carilli's avatar Michael Carilli
Browse files

cleanup

parent 843cdbe0
...@@ -103,7 +103,7 @@ class SGD(Optimizer): ...@@ -103,7 +103,7 @@ class SGD(Optimizer):
for p in params: for p in params:
param_state = self.state[p] param_state = self.state[p]
# torch.optim.SGD initializes momentum in the main loop, we have # torch.optim.SGD initializes momentum in the main loop, we have
# to do it here, and track whether or not we've done so, so that # to do it here, and track whether or not we've done so, so that
# momentum application can be skipped in the main kernel. # momentum application can be skipped in the main kernel.
if 'momentum_buffer' not in param_state: if 'momentum_buffer' not in param_state:
first_run = True first_run = True
...@@ -113,7 +113,7 @@ class SGD(Optimizer): ...@@ -113,7 +113,7 @@ class SGD(Optimizer):
first_run = False first_run = False
momentums.append(param_state['momentum_buffer']) momentums.append(param_state['momentum_buffer'])
# We have all parameters now, split them into appropriate groups for # We have all parameters now, split them into appropriate groups for
# parallel execution, following the 4 possible combos that the underlying # parallel execution, following the 4 possible combos that the underlying
# kernels support: # kernels support:
# grad_type, param_type, momentum_type, requires_fp16_copy # grad_type, param_type, momentum_type, requires_fp16_copy
......
...@@ -49,7 +49,7 @@ struct SGDFunctor ...@@ -49,7 +49,7 @@ struct SGDFunctor
T_grad* grad_in = (T_grad*)tl.addresses[0][tensor_loc]; T_grad* grad_in = (T_grad*)tl.addresses[0][tensor_loc];
grad_in += chunk_idx*chunk_size; grad_in += chunk_idx*chunk_size;
T_weight* weight_in = (T_weight*)tl.addresses[1][tensor_loc]; T_weight* weight_in = (T_weight*)tl.addresses[1][tensor_loc];
weight_in += chunk_idx*chunk_size; weight_in += chunk_idx*chunk_size;
......
...@@ -98,7 +98,6 @@ if "--cuda_ext" in sys.argv: ...@@ -98,7 +98,6 @@ if "--cuda_ext" in sys.argv:
'nvcc':['-maxrregcount=50', 'nvcc':['-maxrregcount=50',
'-O3', '-O3',
'--use_fast_math'] + version_ge_1_1})) '--use_fast_math'] + version_ge_1_1}))
print(ext_modules)
setup( setup(
name='apex', name='apex',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment