Commit 2a915a8b authored by Xu Kai's avatar Xu Kai Committed by binmakeswell
Browse files

fix format (#568)

parent 9420d3ae
from .cuda_native import LayerNorm, FusedScaleMaskSoftmax, MultiHeadAttention
__all__ = [
"LayerNorm", "FusedScaleMaskSoftmax", "MultiHeadAttention"
]
__all__ = ["LayerNorm", "FusedScaleMaskSoftmax", "MultiHeadAttention"]
// modified from https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_adam.cu
// modified from
// https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_adam.cu
#include <torch/extension.h>
void multi_tensor_scale_cuda(
int chunk_size,
at::Tensor noop_flag,
void multi_tensor_scale_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
float scale);
void multi_tensor_sgd_cuda(
int chunk_size,
at::Tensor noop_flag,
void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
float wd,
float momentum,
float dampening,
float lr,
bool nesterov,
bool first_run,
bool wd_after_momentum,
float scale);
float wd, float momentum, float dampening, float lr,
bool nesterov, bool first_run,
bool wd_after_momentum, float scale);
void multi_tensor_adam_cuda(
int chunk_size,
at::Tensor noop_flag,
void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
const float lr,
const float beta1,
const float beta2,
const float epsilon,
const int step,
const int mode,
const float lr, const float beta1,
const float beta2, const float epsilon,
const int step, const int mode,
const int bias_correction,
const float weight_decay);
void multi_tensor_lamb_cuda(
int chunk_size,
at::Tensor noop_flag,
void multi_tensor_lamb_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
const float lr,
const float beta1,
const float beta2,
const float epsilon,
const int step,
const int bias_correction,
const float weight_decay,
const int grad_averaging,
const int mode,
at::Tensor global_grad_norm,
const float lr, const float beta1,
const float beta2, const float epsilon,
const int step, const int bias_correction,
const float weight_decay, const int grad_averaging,
const int mode, at::Tensor global_grad_norm,
const float max_grad_norm,
at::optional<bool> use_nvlamb_python);
std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
int chunk_size,
at::Tensor noop_flag,
std::tuple<at::Tensor, at::Tensor>
multi_tensor_l2norm_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
at::optional<bool> per_tensor_python);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("multi_tensor_scale", &multi_tensor_scale_cuda,
"Fused overflow check + scale for a list of contiguous tensors");
m.def("multi_tensor_sgd", &multi_tensor_sgd_cuda,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment