Commit 2a915a8b, authored by Xu Kai, committed by binmakeswell
Browse files

fix format (#568)

parent 9420d3ae
# Re-export the fused CUDA kernels so callers can import them from this
# package directly (the scraped diff had fused the old/new columns of this
# file into one unparseable line; this is the clean post-commit version).
from .cuda_native import LayerNorm, FusedScaleMaskSoftmax, MultiHeadAttention

__all__ = ["LayerNorm", "FusedScaleMaskSoftmax", "MultiHeadAttention"]
// modified from
// https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_adam.cu
#include <torch/extension.h>

// Forward declarations for the fused multi-tensor CUDA kernels.  The
// definitions live in the corresponding .cu translation units; these
// prototypes exist so the pybind11 module below can bind them.
//
// Common parameters (NOTE(review): semantics inferred from the apex
// originals — confirm against the .cu implementations):
//   chunk_size   - number of elements processed per CUDA block chunk.
//   noop_flag    - 1-element tensor; presumably a nonzero value makes the
//                  kernels skip work (overflow/inf guard) — verify.
//   tensor_lists - parallel lists of tensors the fused op iterates over.

// Fused overflow check + scale for a list of contiguous tensors.
void multi_tensor_scale_cuda(int chunk_size, at::Tensor noop_flag,
                             std::vector<std::vector<at::Tensor>> tensor_lists,
                             float scale);

// Fused SGD step (weight decay, momentum, dampening, optional Nesterov).
// first_run / wd_after_momentum control momentum-buffer init and the
// ordering of weight decay relative to the momentum update.
void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag,
                           std::vector<std::vector<at::Tensor>> tensor_lists,
                           float wd, float momentum, float dampening, float lr,
                           bool nesterov, bool first_run,
                           bool wd_after_momentum, float scale);

// Fused Adam step.  `mode` selects a kernel variant and `bias_correction`
// toggles the beta^step correction (exact meanings are defined in the .cu
// implementation — confirm there).
void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
                            std::vector<std::vector<at::Tensor>> tensor_lists,
                            const float lr, const float beta1,
                            const float beta2, const float epsilon,
                            const int step, const int mode,
                            const int bias_correction,
                            const float weight_decay);

// Fused LAMB step.  global_grad_norm / max_grad_norm feed the trust-ratio
// gradient clipping; use_nvlamb_python optionally enables the NVLAMB
// variant (at::optional, so callers may omit it).
void multi_tensor_lamb_cuda(int chunk_size, at::Tensor noop_flag,
                            std::vector<std::vector<at::Tensor>> tensor_lists,
                            const float lr, const float beta1,
                            const float beta2, const float epsilon,
                            const int step, const int bias_correction,
                            const float weight_decay, const int grad_averaging,
                            const int mode, at::Tensor global_grad_norm,
                            const float max_grad_norm,
                            at::optional<bool> use_nvlamb_python);

// Fused L2-norm over a list of tensors; when per_tensor_python is set the
// second tensor of the returned pair presumably holds per-tensor norms —
// verify against the kernel.
std::tuple<at::Tensor, at::Tensor>
multi_tensor_l2norm_cuda(int chunk_size, at::Tensor noop_flag,
                         std::vector<std::vector<at::Tensor>> tensor_lists,
                         at::optional<bool> per_tensor_python);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
{
m.def("multi_tensor_scale", &multi_tensor_scale_cuda, m.def("multi_tensor_scale", &multi_tensor_scale_cuda,
"Fused overflow check + scale for a list of contiguous tensors"); "Fused overflow check + scale for a list of contiguous tensors");
m.def("multi_tensor_sgd", &multi_tensor_sgd_cuda, m.def("multi_tensor_sgd", &multi_tensor_sgd_cuda,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment