OpenDAS / apex · Commits
"include/vscode:/vscode.git/clone" did not exist on "56215723841565a03ab6f561cd2808262523a9ed"
Commit b9f0995b, authored Aug 20, 2019 by Deyu Fu

add back lamb stage1/2 to amp_C python

Parent: f855f856
Showing 1 changed file with 24 additions and 0 deletions.

csrc/amp_C_frontend.cpp (+24, -0)
@@ -33,6 +33,26 @@ std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
   std::vector<std::vector<at::Tensor>> tensor_lists,
   at::optional<bool> per_tensor_python);
 
+void multi_tensor_lamb_stage1_cuda(
+  int chunk_size,
+  at::Tensor noop_flag,
+  std::vector<std::vector<at::Tensor>> tensor_lists,
+  at::Tensor per_tensor_decay,
+  const int step,
+  const float beta1,
+  const float beta2,
+  const float epsilon,
+  const float global_grad_norm,
+  const float max_global_grad_norm);
+
+void multi_tensor_lamb_stage2_cuda(
+  int chunk_size,
+  at::Tensor noop_flag,
+  std::vector<std::vector<at::Tensor>> tensor_lists,
+  at::Tensor per_tensor_param_norm,
+  at::Tensor per_tensor_update_norm,
+  const float step_size);
+
 void multi_tensor_adam_cuda(
   int chunk_size,
   at::Tensor noop_flag,
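For context (not part of the diff): these two entry points split the LAMB update (You et al., 2019) into an element-wise stage and a per-tensor rescaling stage. Mapping the equations onto the parameter names above is an inference from the signatures, since the CUDA kernels themselves are not shown in this commit; a sketch of the split, with the weight decay \lambda supplied by per_tensor_decay and the learning rate \eta by step_size:

\begin{align*}
  m_t &= \beta_1 m_{t-1} + (1-\beta_1)\,g_t \\
  v_t &= \beta_2 v_{t-1} + (1-\beta_2)\,g_t^2 \\
  u_t &= \frac{m_t}{\sqrt{v_t}+\epsilon} + \lambda\,w_{t-1}
      && \text{(stage 1: the ``update part'')} \\
  w_t &= w_{t-1} - \eta\,\frac{\lVert w_{t-1}\rVert}{\lVert u_t\rVert}\,u_t
      && \text{(stage 2: trust-ratio scaling, applied to the parameters)}
\end{align*}

Here global_grad_norm and max_global_grad_norm presumably clip g_t before stage 1, while per_tensor_param_norm and per_tensor_update_norm carry \lVert w_{t-1}\rVert and \lVert u_t\rVert into stage 2.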
@@ -86,6 +106,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         "out = a*x + b*y for a list of contiguous tensors");
   m.def("multi_tensor_l2norm", &multi_tensor_l2norm_cuda,
         "Computes L2 norm for a list of contiguous tensors");
+  m.def("multi_tensor_lamb_stage1_cuda", &multi_tensor_lamb_stage1_cuda,
+        "Computes update part of LAMB optimizer");
+  m.def("multi_tensor_lamb_stage2_cuda", &multi_tensor_lamb_stage2_cuda,
+        "Completes application of gradient to parameters for LAMB optimizer");
   m.def("multi_tensor_adam", &multi_tensor_adam_cuda,
         "Compute and apply gradient update to parameters for Adam optimizer");
   m.def("multi_tensor_novograd", &multi_tensor_novograd_cuda,
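With these bindings restored, both stages become callable from Python through the amp_C extension. A minimal sketch of a stage-1 call, assuming the extension is built and importable as amp_C; the tensor-list layout and the numeric values are illustrative assumptions, since the kernels that define the real layout are not part of this diff:

import torch
import amp_C  # extension built from csrc/amp_C_frontend.cpp

params      = [torch.randn(1024, device='cuda') for _ in range(4)]
grads       = [torch.randn_like(p) for p in params]
exp_avgs    = [torch.zeros_like(p) for p in params]
exp_avg_sqs = [torch.zeros_like(p) for p in params]

chunk_size = 2048 * 32  # chunk granularity for the multi-tensor apply (assumed value)
noop_flag = torch.zeros(1, dtype=torch.int, device='cuda')
per_tensor_decay = torch.full((len(params),), 0.01, device='cuda')

# ASSUMED inner-list ordering; the real ordering is fixed by the stage-1 kernel.
amp_C.multi_tensor_lamb_stage1_cuda(
    chunk_size, noop_flag,
    [grads, params, exp_avgs, exp_avg_sqs],
    per_tensor_decay,
    1,           # step
    0.9, 0.999,  # beta1, beta2
    1e-6,        # epsilon
    1.0, 1.0)    # global_grad_norm, max_global_grad_norm

# Stage 2 would then consume the per-tensor norms produced around stage 1
# (again, the exact list layout is an assumption):
# amp_C.multi_tensor_lamb_stage2_cuda(chunk_size, noop_flag, stage2_lists,
#     per_tensor_param_norm, per_tensor_update_norm, step_size)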