OpenDAS / apex

Commit f3868524, authored Nov 18, 2021 by Abhishree
Parent: abb6e5ba

Enable Distributed FusedLAMB

Showing 3 changed files, with 401 additions and 234 deletions.
apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb.cpp         +7   -3
apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb_kernel.cu   +41  -23
apex/contrib/optimizers/distributed_fused_lamb.py                  +353 -208
apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb.cpp
@@ -8,11 +8,13 @@ void multi_tensor_lamb_compute_update_term_cuda(
     at::Tensor per_tensor_beta2,
     at::Tensor per_tensor_beta3,
     at::Tensor per_tensor_bias_correction,
-    const int step,
+    at::Tensor step,
     at::Tensor per_tensor_epsilon,
     const int mode,
     at::Tensor per_tensor_decay,
-    const float grad_scale);
+    at::Tensor global_scale,
+    at::Tensor global_grad_norm,
+    const float max_grad_norm);
 
 void multi_tensor_lamb_update_weights_cuda(
     int chunk_size,
@@ -20,8 +22,10 @@ void multi_tensor_lamb_update_weights_cuda(
     std::vector<std::vector<at::Tensor>> tensor_lists,
     at::Tensor per_tensor_param_norm,
     at::Tensor per_tensor_update_norm,
-    const float learning_rate,
+    at::Tensor update_norm_offset,
+    at::Tensor learning_rate,
     at::Tensor per_tensor_decay,
+    at::Tensor global_grad_norm,
     bool use_nvlamb);
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb_kernel.cu
@@ -116,28 +116,36 @@ struct DistOptLAMBStage1Functor
     const MATH_T* per_tensor_beta2,
     const MATH_T* per_tensor_beta3,
     const int* per_tensor_bias_correction,
-    const int step,
+    const int* step,
     const MATH_T* per_tensor_epsilon,
     adamMode_t mode,
     const MATH_T* per_tensor_decay,
-    const float grad_scale)
+    const MATH_T* global_scale,
+    const MATH_T* global_grad_norm,
+    const float max_grad_norm)
   {
     // I'd like this kernel to propagate infs/nans.
     // if(*noop_gmem == 1)
     //   return;
     if (*noop_gmem == 1)
       return;
 
     int tensor_loc = tl.block_to_tensor[blockIdx.x];
     int tensor_num = tl.start_tensor_this_launch + tensor_loc;
     int chunk_idx = tl.block_to_chunk[blockIdx.x];
     int n = tl.sizes[tensor_loc];
 
+    float combined_scale = *global_scale;
+    if (max_grad_norm > 0) {
+      combined_scale = max_grad_norm / (*global_grad_norm / *global_scale + 1e-6);
+      combined_scale = *global_scale / std::min((float) 1.0, combined_scale);
+    }
+
     MATH_T beta1 = per_tensor_beta1[tensor_num];
     MATH_T beta2 = per_tensor_beta2[tensor_num];
     MATH_T beta3 = 1 - beta1;
     MATH_T beta1_correction, beta2_correction;
     if (per_tensor_bias_correction[tensor_num] == 1) {
-      beta1_correction = 1 - pow(beta1, step);
-      beta2_correction = 1 - pow(beta2, step);
+      beta1_correction = 1 - pow(beta1, *step);
+      beta2_correction = 1 - pow(beta2, *step);
     } else {
       beta1_correction = (MATH_T) 1.0;
       beta2_correction = (MATH_T) 1.0;
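The new combined_scale folds loss-scale division and gradient clipping into one divisor: dividing a scaled gradient by it both undoes global_scale and, when max_grad_norm > 0, clips the unscaled global norm to max_grad_norm. A minimal Python sketch of the same arithmetic (hypothetical helper, plain floats; not part of the commit):

    def combined_scale(global_scale, global_grad_norm, max_grad_norm, eps=1e-6):
        # global_scale: the loss scale the gradients were multiplied by.
        # global_grad_norm: L2 norm of the scaled gradients across all tensors.
        scale = global_scale
        if max_grad_norm > 0:
            # clip > 1 means the unscaled norm is already within max_grad_norm
            clip = max_grad_norm / (global_grad_norm / global_scale + eps)
            scale = global_scale / min(1.0, clip)
        return scale

    # Loss scale 1024, scaled norm 4096 -> unscaled norm 4.0; with
    # max_grad_norm = 1.0 the divisor becomes ~4096, i.e. unscale and
    # clip back to (roughly) unit norm in a single division.
    print(combined_scale(1024.0, 4096.0, 1.0))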
@@ -204,7 +212,7 @@ struct DistOptLAMBStage1Functor
       for (int ii = 0; ii < ILP; ii++) {
         if (mode == MOMENT_MODE_0) {
-          MATH_T scaled_grad = r_g[ii] / grad_scale;
+          MATH_T scaled_grad = r_g[ii] / combined_scale;
           // L2 on scaled grad
           scaled_grad = scaled_grad + decay * r_p[ii];
           r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
@@ -215,7 +223,7 @@ struct DistOptLAMBStage1Functor
           r_p[ii] = next_m_unbiased / denom;
         } else {
-          MATH_T scaled_grad = r_g[ii] / grad_scale;
+          MATH_T scaled_grad = r_g[ii] / combined_scale;
           r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
           r_v[ii] = r_v[ii] * beta2 + (1 - beta2) * scaled_grad * scaled_grad;
           MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
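Both Stage-1 branches are the usual Adam moment update over the now combined-scaled gradient; MOMENT_MODE_0 additionally folds weight decay into the gradient as L2 regularization, while the else branch leaves the gradient clean. A scalar Python sketch of one element, mirroring the kernel's register names (illustrative only):

    def stage1_moments(g, p, m, v, beta1, beta2, decay, combined_scale, mode0):
        beta3 = 1.0 - beta1                 # as in the kernel
        scaled_grad = g / combined_scale    # unscale + clip in one divide
        if mode0:
            # MOMENT_MODE_0: L2 regularization folded into the gradient
            scaled_grad = scaled_grad + decay * p
        m = m * beta1 + beta3 * scaled_grad
        v = v * beta2 + (1.0 - beta2) * scaled_grad * scaled_grad
        return m, v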
@@ -274,7 +282,7 @@ struct DistOptLAMBStage1Functor
       for (int ii = 0; ii < ILP; ii++) {
         if (mode == MOMENT_MODE_0) {
-          MATH_T scaled_grad = r_g[ii] / grad_scale;
+          MATH_T scaled_grad = r_g[ii] / combined_scale;
           // L2 on scaled grad
           scaled_grad = scaled_grad + decay * r_p[ii];
           r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
@@ -285,7 +293,7 @@ struct DistOptLAMBStage1Functor
           r_p[ii] = next_m_unbiased / denom;
         } else {
-          MATH_T scaled_grad = r_g[ii] / grad_scale;
+          MATH_T scaled_grad = r_g[ii] / combined_scale;
           r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
           r_v[ii] = r_v[ii] * beta2 + (1 - beta2) * scaled_grad * scaled_grad;
           MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
@@ -321,13 +329,15 @@ struct DistOptLAMBStage2Functor
     TensorListMetadata<3>& tl,
     const MATH_T* per_tensor_param_norm,
     const MATH_T* per_tensor_update_norm,
-    const MATH_T learning_rate,
+    const long* update_norm_offset,
+    const MATH_T* learning_rate,
     const MATH_T* per_tensor_decay,
+    const MATH_T* global_grad_norm,
     bool use_nvlamb)
   {
     // I'd like this kernel to propagate infs/nans.
     // if(*noop_gmem == 1)
     //   return;
     if (*noop_gmem == 1)
       return;
 
     int tensor_loc = tl.block_to_tensor[blockIdx.x];
     int tensor_num = tl.start_tensor_this_launch + tensor_loc;
@@ -336,14 +346,14 @@ struct DistOptLAMBStage2Functor
     MATH_T decay = per_tensor_decay[tensor_num];
 
-    MATH_T ratio = learning_rate;
+    MATH_T ratio = *learning_rate;
     // nvlamb: apply adaptive learning rate to all parameters
     // otherwise, only apply to those with non-zero weight decay
     if (use_nvlamb || (decay != (MATH_T) 0.0)) {
       MATH_T param_norm = per_tensor_param_norm[tensor_num];
-      MATH_T update_norm = per_tensor_update_norm[tensor_num];
-      ratio = (update_norm != 0.0 && param_norm != 0.0) ? learning_rate * (param_norm / update_norm) : learning_rate;
+      MATH_T update_norm = per_tensor_update_norm[update_norm_offset[tensor_num]];
+      ratio = (update_norm != 0.0 && param_norm != 0.0) ? (*learning_rate) * (param_norm / update_norm) : (*learning_rate);
     }
 
     MATH_T* update = (MATH_T*) tl.addresses[0][tensor_loc];
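Stage 2 scales the learning rate by the LAMB trust ratio ||param|| / ||update||, now reading both the learning rate and the update norm through pointers; update_norm_offset adds one level of indirection into per_tensor_update_norm, with the layout decided by the (collapsed) Python frontend. A hedged scalar sketch of the ratio logic:

    def stage2_ratio(lr, param_norm, update_norm, decay, use_nvlamb):
        # nvlamb applies the adaptive rate to every tensor; otherwise
        # only tensors with non-zero weight decay get the trust ratio.
        if use_nvlamb or decay != 0.0:
            if update_norm != 0.0 and param_norm != 0.0:
                return lr * (param_norm / update_norm)
        return lr

    # The weight update is then p <- p - stage2_ratio(...) * update.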
@@ -374,7 +384,7 @@ struct DistOptLAMBStage2Functor
 #pragma unroll
       for (int ii = 0; ii < ILP; ii++) {
-        r_p[ii] = static_cast<MATH_T>(r_p[ii]) - (ratio * r_update[ii]);
+        r_p[ii] = static_cast<MATH_T>(r_p[ii]) - (ratio * r_update[ii]);
         convert(r_p[ii], r_p_copy[ii]);
       }
       load_store(p, r_p, i_start, 0);
@@ -427,11 +437,13 @@ void multi_tensor_lamb_compute_update_term_cuda(
     at::Tensor per_tensor_beta2,
     at::Tensor per_tensor_beta3,
     at::Tensor per_tensor_bias_correction,
-    const int step,
+    at::Tensor step,
     at::Tensor per_tensor_epsilon,
     const int mode,
     at::Tensor per_tensor_decay,
-    const float grad_scale)
+    at::Tensor global_scale,
+    at::Tensor global_grad_norm,
+    const float max_grad_norm)
 {
   using namespace at;
@@ -448,11 +460,13 @@ void multi_tensor_lamb_compute_update_term_cuda(
         per_tensor_beta2.DATA_PTR<scalar_t_2>(),
         per_tensor_beta3.DATA_PTR<scalar_t_2>(),
         per_tensor_bias_correction.DATA_PTR<int>(),
-        step,
+        step.DATA_PTR<int>(),
         per_tensor_epsilon.DATA_PTR<scalar_t_2>(),
         (adamMode_t) mode,
         per_tensor_decay.DATA_PTR<scalar_t_2>(),
-        grad_scale); )))
+        global_scale.DATA_PTR<scalar_t_2>(),
+        global_grad_norm.DATA_PTR<scalar_t_2>(),
+        max_grad_norm); )))
 
   AT_CUDA_CHECK(cudaGetLastError());
 }
@@ -463,8 +477,10 @@ void multi_tensor_lamb_update_weights_cuda(
     std::vector<std::vector<at::Tensor>> tensor_lists,
     at::Tensor per_tensor_param_norm,
     at::Tensor per_tensor_update_norm,
-    const float learning_rate,
+    at::Tensor update_norm_offset,
+    at::Tensor learning_rate,
     at::Tensor per_tensor_decay,
+    at::Tensor global_grad_norm,
     bool use_nvlamb)
 {
   using namespace at;
@@ -480,8 +496,10 @@ void multi_tensor_lamb_update_weights_cuda(
         DistOptLAMBStage2Functor<scalar_t_0, scalar_t_1, scalar_t_2>(),
         per_tensor_param_norm.DATA_PTR<scalar_t_2>(),
         per_tensor_update_norm.DATA_PTR<scalar_t_2>(),
-        (scalar_t_2) learning_rate,
+        update_norm_offset.DATA_PTR<long>(),
+        learning_rate.DATA_PTR<scalar_t_2>(),
         per_tensor_decay.DATA_PTR<scalar_t_2>(),
+        global_grad_norm.DATA_PTR<scalar_t_2>(),
         use_nvlamb); )))
 
   AT_CUDA_CHECK(cudaGetLastError());
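Note the pattern across both launchers: step, learning_rate, global_scale, and global_grad_norm are now at::Tensor arguments dereferenced inside the kernels, so their values never round-trip through the host (which also makes the step friendlier to CUDA graph capture, though the commit does not say so explicitly). A small PyTorch sketch of keeping such scalars device-resident (illustrative, not the commit's code):

    import torch

    # Hypothetical driver-side state: 1-element CUDA tensors stand in for
    # host scalars so fused kernels can read them by pointer.
    step = torch.zeros(1, dtype=torch.int32, device="cuda")
    lr = torch.full((1,), 1e-3, dtype=torch.float32, device="cuda")

    step += 1       # advance the step count entirely on-device
    lr.mul_(0.999)  # apply an LR schedule without a host sync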
apex/contrib/optimizers/distributed_fused_lamb.py
(This diff is collapsed in the page view; +353 additions, -208 deletions, not shown here.)
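The collapsed Python diff presumably computes global_grad_norm on-device and threads the new tensors through to the bindings above. As a rough, hypothetical illustration of producing such a norm (apex itself would use its fused multi-tensor L2-norm kernels rather than a Python loop):

    import torch

    def global_l2_norm(grads):
        # L2 norm over a list of gradient tensors, returned as a 0-d
        # CUDA tensor suitable for passing to the fused LAMB kernels.
        return torch.stack([g.norm(p=2) for g in grads]).norm(p=2)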