Unverified commit 7d4d742b authored by Reza Yazdani, committed by GitHub

Fixing CPU-Adam convergence issue (#503)

* fixing cpu-adam

* fixing copy with optimizer for data and model parallelism

* fixing cpu-adam

* fix cpu-adam

* fix cpu-adam
parent 4c37d705
@@ -122,7 +122,7 @@ void Adam_Optimizer::Step(float* _params,
         float momentum = _exp_avg[k];
         float variance = _exp_avg_sq[k];
         if (_weight_decay > 0 && !_adamw_mode) { grad = param * _weight_decay + grad; }
-        momentum *= momentum * _betta1;
+        momentum = momentum * _betta1;
         momentum = grad * betta1_minus1 + momentum;
         variance = variance * _betta2;
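The one-character change above is the heart of the convergence fix: the old line squared the running first moment instead of decaying it. A minimal scalar sketch in Python (illustrative only, not the extension's code) shows the exponential moving averages that the corrected line restores:

def adam_moments(grad, m, v, beta1=0.9, beta2=0.999):
    """Reference first/second-moment update for a single parameter.

    The fixed C++ line corresponds to `m = m * beta1`; the old line
    effectively computed `m = m * (m * beta1)`, squaring the running
    momentum every step and stalling convergence.
    """
    m = m * beta1 + (1.0 - beta1) * grad         # first moment (momentum)
    v = v * beta2 + (1.0 - beta2) * grad * grad  # second moment (variance)
    return m, v

# Example: m, v = adam_moments(grad=0.1, m=0.05, v=0.002)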
@@ -333,13 +333,31 @@ int create_adam_optimizer(int optimizer_id,
 #if defined(__AVX512__)
     std::cout << "Adam Optimizer #" << optimizer_id
               << " is created with AVX512 arithmetic capability." << std::endl;
+    printf("Config: alpha=%f, betas=(%f, %f), weight_decay=%f, adam_w=%d\n",
+           alpha,
+           betta1,
+           betta2,
+           weight_decay,
+           (int)adamw_mode);
 #else
 #if defined(__AVX256__)
     std::cout << "Adam Optimizer #" << optimizer_id
               << " is created with AVX2 arithmetic capability." << std::endl;
+    printf("Config: alpha=%f, betas=(%f, %f), weight_decay=%f, adam_w=%d\n",
+           alpha,
+           betta1,
+           betta2,
+           weight_decay,
+           (int)adamw_mode);
 #else
     std::cout << "Adam Optimizer #" << optimizer_id
               << " is created with scalar arithmetic capability." << std::endl;
+    printf("Config: alpha=%f, betas=(%f, %f), weight_decay=%f, adam_w=%d\n",
+           alpha,
+           betta1,
+           betta2,
+           weight_decay,
+           (int)adamw_mode);
 #endif
 #endif
     return 0;
@@ -434,8 +452,6 @@ void Adam_Optimizer::Step_8(float* _params,
             param_4[7].data = SIMD_LOAD(_params + i + SIMD_WIDTH * 7);

             if (_weight_decay > 0 && !_adamw_mode) {
-                AVX_Data weight_decay4;
-                weight_decay4.data = SIMD_SET(_weight_decay);
                 grad_4[0].data = SIMD_FMA(param_4[0].data, weight_decay4.data, grad_4[0].data);
                 grad_4[1].data = SIMD_FMA(param_4[1].data, weight_decay4.data, grad_4[1].data);
                 grad_4[2].data = SIMD_FMA(param_4[2].data, weight_decay4.data, grad_4[2].data);
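The deleted declaration of weight_decay4 is presumably made once, earlier in Step_8, since the SIMD_FMA lines below still use it; the guard on _adamw_mode is unchanged. For context, a scalar Python sketch (hypothetical helper, not the vectorized kernel) of the two weight-decay modes that guard selects:

def apply_weight_decay(param, grad, update, lr, weight_decay, adamw_mode):
    """Scalar view of the two weight-decay paths (illustrative only).

    Plain Adam (adamw_mode=False): decay is folded into the gradient before
    the moment updates, as in the guarded FMA block above.
    AdamW (adamw_mode=True): decay is applied directly to the parameter,
    decoupled from the adaptive update.
    """
    if weight_decay > 0 and not adamw_mode:
        grad = param * weight_decay + grad          # coupled L2-style decay
    if weight_decay > 0 and adamw_mode:
        param = param - lr * weight_decay * param   # decoupled decay (AdamW)
    param = param - lr * update                     # Adam step built from grad
    return param, grad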
@@ -593,6 +609,8 @@ void Adam_Optimizer::Step_8(float* _params,
 }

 int ds_adam_step(int optimizer_id,
+                 size_t step,
+                 float lr,
                  torch::Tensor& params,
                  torch::Tensor& grads,
                  torch::Tensor& exp_avg,
@@ -610,13 +628,16 @@ int ds_adam_step(int optimizer_id,
     std::shared_ptr<Adam_Optimizer> opt =
         std::static_pointer_cast<Adam_Optimizer>(s_optimizers[optimizer_id]);

-    opt->IncrementStep();
+    opt->IncrementStep(step);
+    opt->update_lr(lr);
     opt->Step_8(params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr, params_c.size(0));

     return 0;
 }

 int ds_adam_step_plus_copy(int optimizer_id,
+                           size_t step,
+                           float lr,
                            torch::Tensor& params,
                            torch::Tensor& grads,
                            torch::Tensor& exp_avg,
@@ -637,7 +658,8 @@ int ds_adam_step_plus_copy(int optimizer_id,
     std::shared_ptr<Adam_Optimizer> opt =
         std::static_pointer_cast<Adam_Optimizer>(s_optimizers[optimizer_id]);

-    opt->IncrementStep();
+    opt->IncrementStep(step);
+    opt->update_lr(lr);
     opt->Step_8(
         params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr, params_c.size(0), gpu_params_ptr);
......
@@ -59,6 +59,7 @@ public:
           _weight_decay(weight_decay),
           _betta1_t(1.0),
           _betta2_t(1.0),
+          _step(0),
           _buf_index(false),
           _adamw_mode(adamw_mode)
     {
@@ -88,11 +89,18 @@ public:
                 float* _exp_avg_sq,
                 size_t _param_size,
                 __half* dev_params = nullptr);
-    inline void IncrementStep()
+    inline void IncrementStep(size_t step)
     {
-        _betta1_t *= _betta1;
-        _betta2_t *= _betta2;
+        if (_step < step) {
+            _step++;
+            if (_step != step) {
+                throw std::runtime_error("Optimizer lost track of step count!\n");
+            }
+            _betta1_t *= _betta1;
+            _betta2_t *= _betta2;
+        }
     }
+    inline void update_lr(float lr) { _alpha = lr; }

 private:
 #if defined(__AVX512__) or defined(__AVX256__)
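IncrementStep now takes the step count from the caller and advances the cached beta powers only when the step actually moves forward, raising if the two counters ever diverge; update_lr re-syncs the learning rate on every call. A rough Python mirror of that bookkeeping (the names are illustrative; the real state lives in Adam_Optimizer):

class StepTracker:
    """Illustrative Python mirror of the C++ IncrementStep/update_lr logic."""

    def __init__(self, beta1=0.9, beta2=0.999, lr=1e-3):
        self.beta1, self.beta2 = beta1, beta2
        self.beta1_t, self.beta2_t = 1.0, 1.0  # running beta^step products
        self.step = 0
        self.alpha = lr

    def increment_step(self, step):
        # Advance only if the caller's step is ahead; raise if the counts
        # ever diverge by more than one, as the C++ version does.
        if self.step < step:
            self.step += 1
            if self.step != step:
                raise RuntimeError("Optimizer lost track of step count!")
            self.beta1_t *= self.beta1  # used for bias correction
            self.beta2_t *= self.beta2

    def update_lr(self, lr):
        self.alpha = lr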
@@ -114,6 +122,7 @@ private:
     float _betta1_t;
     float _betta2_t;
+    size_t _step;

     float* _doubled_buffer[2];
     bool _buf_index;
......
@@ -95,32 +95,38 @@ class DeepSpeedCPUAdam(torch.optim.Optimizer):
                 if p.grad is None:
                     continue

-                grad = p.grad.data
                 state = self.state[p]
                 # State initialization
                 if len(state) == 0:
                     print(f'group {group_id} param {param_id} = {p.numel()}')
                     state['step'] = 0
                     # gradient momentums
-                    state['exp_avg'] = torch.zeros_like(p.data, device='cpu')
+                    state['exp_avg'] = torch.zeros_like(
+                        p.data,
+                        memory_format=torch.preserve_format)
                     # gradient variances
-                    state['exp_avg_sq'] = torch.zeros_like(p.data, device='cpu')
+                    state['exp_avg_sq'] = torch.zeros_like(
+                        p.data,
+                        memory_format=torch.preserve_format)

-                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                 state['step'] += 1

                 if fp16_param_groups is not None:
-                    p_fp16 = fp16_param_groups[group_id][param_id]
-                    ds_opt_adam.adam_update_copy(self.opt_id,
-                                                 p.data,
-                                                 grad,
-                                                 exp_avg,
-                                                 exp_avg_sq,
-                                                 p_fp16)
+                    ds_opt_adam.adam_update_copy(
+                        self.opt_id,
+                        state['step'],
+                        group['lr'],
+                        p.data,
+                        p.grad.data,
+                        state['exp_avg'],
+                        state['exp_avg_sq'],
+                        fp16_param_groups[group_id][param_id].data)
                 else:
                     ds_opt_adam.adam_update(self.opt_id,
+                                            state['step'],
+                                            group['lr'],
                                             p.data,
-                                            grad,
-                                            exp_avg,
-                                            exp_avg_sq)
+                                            p.grad.data,
+                                            state['exp_avg'],
+                                            state['exp_avg_sq'])
         return loss
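Passing state['step'] and group['lr'] down on every call keeps the C++ side's bias correction and step size in lockstep with the Python optimizer state, even when an LR scheduler changes the rate or the step counter is restored from a checkpoint. A small illustrative helper (not part of DeepSpeed) for the effective step size implied by a given step:

import math

def effective_step_size(lr, step, beta1=0.9, beta2=0.999):
    """Bias-corrected Adam step size for a given global step (illustrative).

    This is the quantity the C++ kernel derives from the _betta1_t/_betta2_t
    powers it keeps in sync via IncrementStep(step) and update_lr(lr).
    """
    bias_correction1 = 1 - beta1 ** step
    bias_correction2 = 1 - beta2 ** step
    return lr * math.sqrt(bias_correction2) / bias_correction1

# e.g. effective_step_size(1e-3, step=1) ≈ 3.2e-4: early steps are rescaled
# by the corrections, so a stale step count skews the update magnitude.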
@@ -1416,8 +1416,11 @@ class FP16_DeepSpeedZeroOptimizer(object):
         if self.deepspeed_adam_offload:
             from deepspeed.ops.adam import DeepSpeedCPUAdam
             if type(self.optimizer) == DeepSpeedCPUAdam:
-                self.optimizer.step(
-                    fp16_param_groups=self.parallel_partitioned_fp16_groups)
+                fp16_param_groups = [
+                    fp16_partitions[partition_id]
+                    for fp16_partitions in self.parallel_partitioned_fp16_groups
+                ]
+                self.optimizer.step(fp16_param_groups=fp16_param_groups)
             else:
                 self.optimizer.step()
             for fp16_partitions, fp32_partition in zip(self.parallel_partitioned_fp16_groups, self.single_partition_of_fp32_groups):
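With ZeRO partitioning, parallel_partitioned_fp16_groups holds one fp16 partition per data-parallel rank for each parameter group; the fix hands DeepSpeedCPUAdam only this rank's partitions, so the fp16 copy it writes back matches the fp32 master partition the rank owns. A hedged sketch of that selection as a standalone helper (partition_id is assumed to be this rank's index in the data-parallel group):

import torch.distributed as dist

def local_fp16_param_groups(parallel_partitioned_fp16_groups, dp_process_group=None):
    """Hypothetical helper mirroring the fix above: pick this rank's fp16
    partition out of each parameter group before calling
    optimizer.step(fp16_param_groups=...)."""
    partition_id = dist.get_rank(group=dp_process_group)
    return [
        fp16_partitions[partition_id]
        for fp16_partitions in parallel_partitioned_fp16_groups
    ]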
......