Unverified commit ee1ffe2e, authored by Reza Yazdani, committed by GitHub

CPU-Adam fix for scalar mode (#735)


Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
parent 1fcc5f7a
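
This change fixes the scalar (non-vectorized) path of Adam_Optimizer::Step. When the extension is built without AVX support, rounded_size stays 0 and the leftover loop processes the entire tensor: the old code wrote every element into a single TILE-sized pinned buffer (overflowing it whenever the tensor is larger than one tile), needlessly round-tripped each value through __half even though _doubled_buffer holds floats, and launched one copy of (_param_size - rounded_size) elements on the shared current stream. The scalar loop is now tiled and double-buffered like the SIMD paths. In addition, all paths move their device copies onto a pair of dedicated CUDA streams (_streams) that ping-pong together with _doubled_buffer: a buffer is reused only after synchronizing the stream that last copied out of it, and the new SynchronizeStreams() drains both streams before a step returns.

For orientation, the pattern now shared by all paths looks roughly like the standalone sketch below. This is an illustration, not the DeepSpeed source: param_update is a hypothetical stand-in for launch_param_update, and a dummy increment stands in for the Adam math.

#include <cuda_runtime.h>
#include <cstddef>

#define TILE (1024 * 1024)

// Hypothetical stand-in for launch_param_update(): push one tile of
// freshly updated parameters to the device on the given stream.
static void param_update(const float* host_buf, float* dev_params, size_t n, cudaStream_t stream)
{
    cudaMemcpyAsync(dev_params, host_buf, n * sizeof(float), cudaMemcpyHostToDevice, stream);
}

void step(float* params, float* dev_params, size_t param_size)
{
    float* doubled_buffer[2];
    cudaStream_t streams[2];
    cudaMallocHost((void**)&doubled_buffer[0], TILE * sizeof(float));
    cudaMallocHost((void**)&doubled_buffer[1], TILE * sizeof(float));
    cudaStreamCreate(&streams[0]);
    cudaStreamCreate(&streams[1]);

    bool buf_index = false;
    for (size_t t = 0; t < param_size; t += TILE) {
        size_t copy_size = (t + TILE) > param_size ? param_size - t : TILE;
        // This buffer was handed to a copy two tiles ago; wait for that
        // copy to finish before overwriting the pinned memory.
        if ((t / TILE) >= 2) cudaStreamSynchronize(streams[buf_index]);
#pragma omp parallel for
        for (size_t k = 0; k < copy_size; k++) {
            params[t + k] += 1.0f;  // placeholder for the Adam update
            doubled_buffer[buf_index][k] = params[t + k];
        }
        param_update(doubled_buffer[buf_index], dev_params + t, copy_size, streams[buf_index]);
        buf_index = !buf_index;  // ping-pong to the other buffer/stream
    }
    // Mirror of Adam_Optimizer::SynchronizeStreams(): both in-flight
    // copies must land before the caller reads dev_params.
    for (int i = 0; i < 2; i++) cudaStreamSynchronize(streams[i]);

    cudaStreamDestroy(streams[0]);
    cudaStreamDestroy(streams[1]);
    cudaFreeHost(doubled_buffer[0]);
    cudaFreeHost(doubled_buffer[1]);
}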
csrc/adam/cpu_adam.cpp
@@ -62,6 +62,8 @@ void Adam_Optimizer::Step(float* _params,
             size_t copy_size = TILE;
             if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
             size_t offset = copy_size + t;
+            if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
+
 #pragma omp parallel for
             for (size_t i = t; i < offset; i += SIMD_WIDTH) {
                 AVX_Data grad_4;
@@ -101,10 +103,8 @@ void Adam_Optimizer::Step(float* _params,
                 SIMD_STORE(_exp_avg_sq + i, variance_4.data);
             }
             if (dev_params) {
-                launch_param_update(_doubled_buffer[_buf_index],
-                                    dev_params + t,
-                                    copy_size,
-                                    Context::Instance().GetCurrentStream());
+                launch_param_update(
+                    _doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]);
                 _buf_index = !_buf_index;
             }
         }
@@ -112,8 +112,13 @@ void Adam_Optimizer::Step(float* _params,
 #endif
     if (_param_size > rounded_size) {
+        for (size_t t = rounded_size; t < _param_size; t += TILE) {
+            size_t copy_size = TILE;
+            if ((t + TILE) > _param_size) copy_size = _param_size - t;
+            size_t offset = copy_size + t;
+            if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
 #pragma omp parallel for
-        for (size_t k = rounded_size; k < _param_size; k++) {
-            float grad = grads[k];
-            float param = _params[k];
-            float momentum = _exp_avg[k];
+            for (size_t k = t; k < offset; k++) {
+                float grad = grads[k];
+                float param = _params[k];
+                float momentum = _exp_avg[k];
@@ -131,17 +136,17 @@ void Adam_Optimizer::Step(float* _params,
                 grad = momentum / grad;
                 if (_weight_decay > 0 && _adamw_mode) { param += w_decay * param; }
                 param = grad * step_size + param;
-                if (dev_params) _doubled_buffer[_buf_index][k - rounded_size] = (__half)param;
+                if (dev_params) _doubled_buffer[_buf_index][k - t] = param;
                 _params[k] = param;
                 _exp_avg[k] = momentum;
                 _exp_avg_sq[k] = variance;
             }
             if (dev_params) {
-                launch_param_update(_doubled_buffer[_buf_index],
-                                    dev_params + rounded_size,
-                                    (_param_size - rounded_size),
-                                    Context::Instance().GetCurrentStream());
+                launch_param_update(
+                    _doubled_buffer[_buf_index], dev_params + t, (copy_size), _streams[_buf_index]);
+                _buf_index = !_buf_index;
             }
+        }
     }
 }
@@ -189,6 +194,7 @@ void Adam_Optimizer::Step_4(float* _params,
             size_t copy_size = TILE;
             if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
             size_t offset = copy_size + t;
+            if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
 #pragma omp parallel for
             for (size_t i = t; i < offset; i += (SIMD_WIDTH << 2)) {
                 AVX_Data grad_4[4];
@@ -295,10 +301,8 @@ void Adam_Optimizer::Step_4(float* _params,
             }
             if (dev_params) {
-                launch_param_update(_doubled_buffer[_buf_index],
-                                    dev_params + t,
-                                    copy_size,
-                                    Context::Instance().GetCurrentStream());
+                launch_param_update(
+                    _doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]);
                 _buf_index = !_buf_index;
             }
         }
@@ -400,6 +404,7 @@ void Adam_Optimizer::Step_8(float* _params,
             size_t copy_size = TILE;
             if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
             size_t offset = copy_size + t;
+            if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
 #pragma omp parallel for
             for (size_t i = t; i < offset; i += (SIMD_WIDTH << 3)) {
                 AVX_Data grad_4[8];
@@ -582,10 +587,8 @@ void Adam_Optimizer::Step_8(float* _params,
                 SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH * 7, variance_4[7].data);
             }
             if (dev_params) {
-                launch_param_update(_doubled_buffer[_buf_index],
-                                    dev_params + t,
-                                    copy_size,
-                                    Context::Instance().GetCurrentStream());
+                launch_param_update(
+                    _doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]);
                 _buf_index = !_buf_index;
             }
         }
@@ -628,6 +631,7 @@ int ds_adam_step(int optimizer_id,
     opt->update_state(lr, epsilon, weight_decay, bias_correction);
     opt->Step_8(params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr, params_c.size(0));
+    opt->SynchronizeStreams();

     return 0;
 }
@@ -664,6 +668,7 @@ int ds_adam_step_plus_copy(int optimizer_id,
     opt->Step_8(
         params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr, params_c.size(0), gpu_params_ptr);
+    opt->SynchronizeStreams();

     return 0;
 }
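
Because the parameter copies are now asynchronous on the optimizer's own streams, both bindings call SynchronizeStreams() before returning, so the device parameters are fully updated by the time control goes back to Python. (Step_8 falls back to Step_4, and Step_4 to Step, for leftover elements, so the tiled scalar loop above covers every entry point.)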
csrc/includes/context.h
@@ -81,6 +81,8 @@ public:
         return stream;
     }

+    cudaStream_t GetNewStream() { return at::cuda::getStreamFromPool(); }
+
     cublasHandle_t GetCublasHandle() { return _cublasHandle; }

     std::pair<uint64_t, uint64_t> IncrementOffset(uint64_t offset_inc)
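
A side note on GetNewStream(): at::cuda::getStreamFromPool() returns an at::cuda::CUDAStream that is owned and recycled by PyTorch's stream pool and converts implicitly to cudaStream_t, so Adam_Optimizer can hold the raw handles in _streams without creating or destroying streams itself.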
csrc/includes/cpu_adam.h
@@ -65,6 +65,9 @@ public:
     {
         cudaMallocHost((void**)_doubled_buffer, TILE * sizeof(float));
         cudaMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(float));
+
+        _streams[0] = Context::Instance().GetCurrentStream();
+        _streams[1] = Context::Instance().GetNewStream();
     }
     ~Adam_Optimizer()
     {
@@ -89,7 +92,10 @@ public:
               float* _exp_avg_sq,
               size_t _param_size,
               __half* dev_params = nullptr);
+    inline void SynchronizeStreams()
+    {
+        for (int i = 0; i < 2; i++) cudaStreamSynchronize(_streams[i]);
+    }
     inline void IncrementStep(size_t step, float beta1, float beta2)
     {
         if (beta1 != _betta1 || beta2 != _betta2) {
@@ -152,4 +158,6 @@ private:
     float* _doubled_buffer[2];
     bool _buf_index;
     bool _adamw_mode;
+
+    cudaStream_t _streams[2];
 };