Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
13694642
Commit
13694642
authored
May 16, 2022
by
Ofey Chan
Committed by
binmakeswell
May 17, 2022
Browse files
[NFC] polish colossalai/kernel/cuda_native/csrc/layer_norm_cuda.cpp code style (#973)
parent
598cde4a
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
12 additions
and
12 deletions
+12
-12
colossalai/kernel/cuda_native/csrc/layer_norm_cuda.cpp
colossalai/kernel/cuda_native/csrc/layer_norm_cuda.cpp
+12
-12
No files found.
colossalai/kernel/cuda_native/csrc/layer_norm_cuda.cpp
View file @
13694642
...
@@ -2,11 +2,13 @@
...
@@ -2,11 +2,13 @@
* https://github.com/NVIDIA/apex
* https://github.com/NVIDIA/apex
* with minor changes. */
* with minor changes. */
#include "compat.h"
#include <cassert>
#include <torch/extension.h>
#include <torch/extension.h>
#include <cassert>
#include <vector>
#include <vector>
#include "compat.h"
namespace
{
namespace
{
void
compute_n1_n2
(
at
::
Tensor
input
,
at
::
IntArrayRef
normalized_shape
,
int
&
n1
,
void
compute_n1_n2
(
at
::
Tensor
input
,
at
::
IntArrayRef
normalized_shape
,
int
&
n1
,
...
@@ -65,7 +67,7 @@ void check_args(at::Tensor input, at::IntArrayRef normalized_shape,
...
@@ -65,7 +67,7 @@ void check_args(at::Tensor input, at::IntArrayRef normalized_shape,
check_args
(
input
,
normalized_shape
,
n1
,
n2
);
check_args
(
input
,
normalized_shape
,
n1
,
n2
);
check_args
(
normalized_shape
,
gamma
,
beta
);
check_args
(
normalized_shape
,
gamma
,
beta
);
}
}
}
// namespace
}
// namespace
void
cuda_layer_norm
(
at
::
Tensor
*
output
,
at
::
Tensor
*
mean
,
at
::
Tensor
*
invvar
,
void
cuda_layer_norm
(
at
::
Tensor
*
output
,
at
::
Tensor
*
mean
,
at
::
Tensor
*
invvar
,
at
::
Tensor
*
input
,
int
n1
,
int
n2
,
at
::
Tensor
*
input
,
int
n1
,
int
n2
,
...
@@ -73,17 +75,16 @@ void cuda_layer_norm(at::Tensor *output, at::Tensor *mean, at::Tensor *invvar,
...
@@ -73,17 +75,16 @@ void cuda_layer_norm(at::Tensor *output, at::Tensor *mean, at::Tensor *invvar,
at
::
Tensor
*
beta
,
double
epsilon
);
at
::
Tensor
*
beta
,
double
epsilon
);
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x)
\
#define CHECK_CONTIGUOUS(x) \
TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x)
\
#define CHECK_INPUT(x) \
CHECK_CUDA(x);
\
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
CHECK_CONTIGUOUS(x)
std
::
vector
<
at
::
Tensor
>
layer_norm_affine
(
at
::
Tensor
input
,
std
::
vector
<
at
::
Tensor
>
layer_norm_affine
(
at
::
Tensor
input
,
at
::
IntArrayRef
normalized_shape
,
at
::
IntArrayRef
normalized_shape
,
at
::
Tensor
gamma
,
at
::
Tensor
beta
,
at
::
Tensor
gamma
,
at
::
Tensor
beta
,
double
epsilon
)
{
double
epsilon
)
{
CHECK_INPUT
(
input
);
CHECK_INPUT
(
input
);
CHECK_INPUT
(
gamma
);
CHECK_INPUT
(
gamma
);
CHECK_INPUT
(
beta
);
CHECK_INPUT
(
beta
);
...
@@ -109,11 +110,10 @@ void cuda_layer_norm_gradient(at::Tensor *dout, at::Tensor *mean,
...
@@ -109,11 +110,10 @@ void cuda_layer_norm_gradient(at::Tensor *dout, at::Tensor *mean,
double
epsilon
,
at
::
Tensor
*
grad_input
,
double
epsilon
,
at
::
Tensor
*
grad_input
,
at
::
Tensor
*
grad_gamma
,
at
::
Tensor
*
grad_beta
);
at
::
Tensor
*
grad_gamma
,
at
::
Tensor
*
grad_beta
);
std
::
vector
<
at
::
Tensor
>
std
::
vector
<
at
::
Tensor
>
layer_norm_gradient_affine
(
layer_norm_gradient_affine
(
at
::
Tensor
dout
,
at
::
Tensor
mean
,
at
::
Tensor
invvar
,
at
::
Tensor
dout
,
at
::
Tensor
mean
,
at
::
Tensor
invvar
,
at
::
Tensor
input
,
at
::
Tensor
input
,
at
::
IntArrayRef
normalized_shape
,
at
::
IntArrayRef
normalized_shape
,
at
::
Tensor
gamma
,
at
::
Tensor
beta
,
at
::
Tensor
gamma
,
at
::
Tensor
beta
,
double
epsilon
)
{
double
epsilon
)
{
CHECK_INPUT
(
dout
);
CHECK_INPUT
(
dout
);
CHECK_INPUT
(
mean
);
CHECK_INPUT
(
mean
);
CHECK_INPUT
(
invvar
);
CHECK_INPUT
(
invvar
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment