OpenDAS / ColossalAI · Commits · 331683bf

Commit 331683bf authored Apr 02, 2022 by shenggan, committed Apr 06, 2022 by binmakeswell
Parent: c336cd30

[NFC] polish colossalai/kernel/cuda_native/csrc/layer_norm_cuda_kernel.cu code style (#661)

Showing 1 changed file with 364 additions and 494 deletions
colossalai/kernel/cuda_native/csrc/layer_norm_cuda_kernel.cu (+364, -494), view file @ 331683bf
...
@@ -2,23 +2,17 @@
 * https://github.com/NVIDIA/apex
 * with minor changes. */

#include "ATen/ATen.h"
#include "ATen/AccumulateType.h"
#include "ATen/cuda/CUDAContext.h"
#include "ATen/cuda/DeviceUtils.cuh"

#include <cuda.h>
#include <cuda_runtime.h>
#include "type_shim.h"

template <typename U>
__device__ void cuWelfordOnlineSum(const U curr, U &mu, U &sigma2, U &count) {
  count = count + U(1);
  U delta = curr - mu;
  U lmean = mu + delta / count;
...
@@ -27,15 +21,9 @@ void cuWelfordOnlineSum(
  sigma2 = sigma2 + delta * delta2;
}

template <typename U>
__device__ void cuChanOnlineSum(const U muB, const U sigma2B, const U countB,
                                U &mu, U &sigma2, U &count) {
  U delta = muB - mu;
  U nA = count;
  U nB = countB;
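For reference, the per-element update applied by cuWelfordOnlineSum above is Welford's online recurrence; written out (a restatement of the code, with sigma2 playing the role of the running sum of squared deviations M2, which is divided by n2 only at the end of cuWelfordMuSigma2):

$$ n \leftarrow n + 1, \qquad \delta = x - \mu, \qquad \mu \leftarrow \mu + \frac{\delta}{n}, \qquad M_2 \leftarrow M_2 + \delta\,(x - \mu_{\text{new}}) $$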
...
@@ -44,7 +32,7 @@ void cuChanOnlineSum(
  if (nX > U(0)) {
    nA = nA / nX;
    nB = nB / nX;
    mu = nA * mu + nB * muB;
    sigma2 = sigma2 + sigma2B + delta * delta * nA * nB * nX;
  } else {
    mu = U(0);
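cuChanOnlineSum merges two partial (count, mean, M2) accumulators, which is how the warp and block reductions further down combine per-lane statistics. In the code's notation (nA = count, nB = countB, nX = nA + nB) this is Chan et al.'s parallel update, restated here in math:

$$ \delta = \mu_B - \mu_A, \qquad \mu \leftarrow \frac{n_A\,\mu_A + n_B\,\mu_B}{n_X}, \qquad M_2 \leftarrow M_{2,A} + M_{2,B} + \delta^2\,\frac{n_A n_B}{n_X} $$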
...
@@ -52,16 +40,10 @@ void cuChanOnlineSum(
  }
}

template <typename T, typename U>
__device__ void cuWelfordMuSigma2(const T *__restrict__ vals, const int n1,
                                  const int n2, const int i1, U &mu, U &sigma2,
                                  U *buf) {
  // Assumptions:
  // 1) blockDim.x == warpSize
  // 2) Tensor is contiguous
...
@@ -69,7 +51,7 @@ void cuWelfordMuSigma2(
  //
  // compute variance and mean over n2
  U count = U(0);
  mu = U(0);
  sigma2 = U(0);
  if (i1 < n1) {
    // one warp normalizes one n1 index,
...
@@ -77,46 +59,47 @@ void cuWelfordMuSigma2(
    // initialize with standard Welford algorithm
    const int numx = blockDim.x * blockDim.y;
    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
    const T *lvals = vals + i1 * n2;
    int l = 4 * thrx;
    for (; l + 3 < n2; l += 4 * numx) {
      for (int k = 0; k < 4; ++k) {
        U curr = static_cast<U>(lvals[l + k]);
        cuWelfordOnlineSum<U>(curr, mu, sigma2, count);
      }
    }
    for (; l < n2; ++l) {
      U curr = static_cast<U>(lvals[l]);
      cuWelfordOnlineSum<U>(curr, mu, sigma2, count);
    }
    // intra-warp reductions
    for (int l = 0; l <= 4; ++l) {
      int srcLaneB = (threadIdx.x + (1 << l)) & 31;
      U muB = WARP_SHFL(mu, srcLaneB);
      U countB = WARP_SHFL(count, srcLaneB);
      U sigma2B = WARP_SHFL(sigma2, srcLaneB);
      cuChanOnlineSum<U>(muB, sigma2B, countB, mu, sigma2, count);
    }
    // threadIdx.x == 0 has correct values for each warp
    // inter-warp reductions
    if (blockDim.y > 1) {
      U *ubuf = (U *)buf;
      U *ibuf = (U *)(ubuf + blockDim.y);
      for (int offset = blockDim.y / 2; offset > 0; offset /= 2) {
        // upper half of warps write to shared
        if (threadIdx.x == 0 && threadIdx.y >= offset &&
            threadIdx.y < 2 * offset) {
          const int wrt_y = threadIdx.y - offset;
          ubuf[2 * wrt_y] = mu;
          ubuf[2 * wrt_y + 1] = sigma2;
          ibuf[wrt_y] = count;
        }
        __syncthreads();
        // lower half merges
        if (threadIdx.x == 0 && threadIdx.y < offset) {
          U muB = ubuf[2 * threadIdx.y];
          U sigma2B = ubuf[2 * threadIdx.y + 1];
          U countB = ibuf[threadIdx.y];
          cuChanOnlineSum<U>(muB, sigma2B, countB, mu, sigma2, count);
        }
        __syncthreads();
      }
...
@@ -127,25 +110,19 @@ void cuWelfordMuSigma2(
      }
      __syncthreads();
      mu = ubuf[0];
      sigma2 = ubuf[1] / U(n2);
      // don't care about final value of count, we know count == n2
    } else {
      mu = WARP_SHFL(mu, 0);
      sigma2 = WARP_SHFL(sigma2 / U(n2), 0);
    }
  }
}

template <>
__device__ void cuWelfordMuSigma2(const at::Half *__restrict__ vals,
                                  const int n1, const int n2, const int i1,
                                  float &mu, float &sigma2, float *buf) {
  // Assumptions:
  // 1) blockDim.x == warpSize
  // 2) Tensor is contiguous
...
@@ -153,7 +130,7 @@ void cuWelfordMuSigma2(
  //
  // compute variance and mean over n2
  float count = 0.0f;
  mu = float(0);
  sigma2 = float(0);
  if (i1 < n1) {
    // one warp normalizes one n1 index,
...
@@ -161,57 +138,58 @@ void cuWelfordMuSigma2(
    // initialize with standard Welford algorithm
    const int numx = blockDim.x * blockDim.y;
    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
    const at::Half *lvals = vals + i1 * n2;
    int l = 8 * thrx;
    if ((((size_t)lvals) & 3) != 0) {
      // 16 bit alignment
      // first thread consumes first point
      if (thrx == 0) {
        float curr = static_cast<float>(lvals[0]);
        cuWelfordOnlineSum(curr, mu, sigma2, count);
      }
      ++l;
    }
    // at this point, lvals[l] are 32 bit aligned for all threads.
    for (; l + 7 < n2; l += 8 * numx) {
      for (int k = 0; k < 8; k += 2) {
        float2 curr = __half22float2(*((__half2 *)(lvals + l + k)));
        cuWelfordOnlineSum(curr.x, mu, sigma2, count);
        cuWelfordOnlineSum(curr.y, mu, sigma2, count);
      }
    }
    for (; l < n2; ++l) {
      float curr = static_cast<float>(lvals[l]);
      cuWelfordOnlineSum(curr, mu, sigma2, count);
    }
    // intra-warp reductions
    for (int l = 0; l <= 4; ++l) {
      int srcLaneB = (threadIdx.x + (1 << l)) & 31;
      float muB = WARP_SHFL(mu, srcLaneB);
      float countB = WARP_SHFL(count, srcLaneB);
      float sigma2B = WARP_SHFL(sigma2, srcLaneB);
      cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count);
    }
    // threadIdx.x == 0 has correct values for each warp
    // inter-warp reductions
    if (blockDim.y > 1) {
      float *ubuf = (float *)buf;
      float *ibuf = (float *)(ubuf + blockDim.y);
      for (int offset = blockDim.y / 2; offset > 0; offset /= 2) {
        // upper half of warps write to shared
        if (threadIdx.x == 0 && threadIdx.y >= offset &&
            threadIdx.y < 2 * offset) {
          const int wrt_y = threadIdx.y - offset;
          ubuf[2 * wrt_y] = mu;
          ubuf[2 * wrt_y + 1] = sigma2;
          ibuf[wrt_y] = count;
        }
        __syncthreads();
        // lower half merges
        if (threadIdx.x == 0 && threadIdx.y < offset) {
          float muB = ubuf[2 * threadIdx.y];
          float sigma2B = ubuf[2 * threadIdx.y + 1];
          float countB = ibuf[threadIdx.y];
          cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count);
        }
        __syncthreads();
      }
...
@@ -222,28 +200,32 @@ void cuWelfordMuSigma2(
      }
      __syncthreads();
      mu = ubuf[0];
      sigma2 = ubuf[1] / float(n2);
      // don't care about final value of count, we know count == n2
    } else {
      mu = WARP_SHFL(mu, 0);
      sigma2 = WARP_SHFL(sigma2 / float(n2), 0);
    }
  }
}

template <typename U>
U rsqrt(U v) {
  return U(1) / sqrt(v);
}
template <>
float rsqrt(float v) {
  return rsqrtf(v);
}
template <>
double rsqrt(double v) {
  return rsqrt(v);
}

namespace {
// This is the un-specialized struct. Note that we prevent instantiation of
// this struct by putting an undefined symbol in the function body so it won't
// compile.
// template <typename T>
// struct SharedMemory
// {
...
@@ -260,51 +242,43 @@ template <typename T>
struct SharedMemory;

template <>
struct SharedMemory<float> {
  __device__ float *getPointer() {
    extern __shared__ float s_float[];
    return s_float;
  }
};
}  // namespace

template <typename T, typename U, typename V>
__global__ void cuApplyLayerNorm(V *__restrict__ output_vals,
                                 U *__restrict__ mean, U *__restrict__ invvar,
                                 const T *__restrict__ vals, const int n1,
                                 const int n2, const U epsilon,
                                 const V *__restrict__ gamma,
                                 const V *__restrict__ beta) {
  // Assumptions:
  // 1) blockDim.x == warpSize
  // 2) Tensors are contiguous
  //
  for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
    SharedMemory<U> shared;
    U *buf = shared.getPointer();
    U mu, sigma2;
    cuWelfordMuSigma2(vals, n1, n2, i1, mu, sigma2, buf);
    const T *lvals = vals + i1 * n2;
    V *ovals = output_vals + i1 * n2;
    U c_invvar = rsqrt(sigma2 + epsilon);
    const int numx = blockDim.x * blockDim.y;
    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
    if (gamma != NULL && beta != NULL) {
      for (int i = thrx; i < n2; i += numx) {
        U curr = static_cast<U>(lvals[i]);
        ovals[i] = gamma[i] * static_cast<V>(c_invvar * (curr - mu)) + beta[i];
      }
    } else {
      for (int i = thrx; i < n2; i += numx) {
        U curr = static_cast<U>(lvals[i]);
        ovals[i] = static_cast<V>(c_invvar * (curr - mu));
      }
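The elementwise store in cuApplyLayerNorm corresponds to the usual layer-norm forward, with mu and sigma2 the per-row mean and biased variance over the trailing n2 elements and invvar cached as the reciprocal standard deviation:

$$ y_i = \gamma_i \, \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta_i, \qquad \text{invvar} = \frac{1}{\sqrt{\sigma^2 + \epsilon}} $$

When gamma and beta are NULL the affine terms are simply omitted, as in the else branch above.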
...
@@ -316,252 +290,228 @@ void cuApplyLayerNorm(
  }
}

template <typename T, typename U, typename V>
__device__ void cuLoadWriteStridedInputs(
    const int i1_block, const int thr_load_row_off, const int thr_load_col_off,
    const int i2_off, const int row_stride, U *warp_buf1, U *warp_buf2,
    const T *input, const V *dout, const int i1_end, const int n2,
    const U *__restrict__ mean, const U *__restrict__ invvar) {
  int i1 = i1_block + thr_load_row_off;
  if (i1 < i1_end) {
    U curr_mean = mean[i1];
    U curr_invvar = invvar[i1];
    for (int k = 0; k < blockDim.y; ++k) {
      int i2 = i2_off + k;
      int load_idx = i1 * n2 + i2;
      int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
      if (i2 < n2) {
        U curr_input = static_cast<U>(input[load_idx]);
        U curr_dout = static_cast<U>(dout[load_idx]);
        warp_buf1[write_idx] = curr_dout;
        warp_buf2[write_idx] =
            curr_dout * (curr_input - curr_mean) * curr_invvar;
      } else {
        warp_buf1[write_idx] = U(0);
        warp_buf2[write_idx] = U(0);
      }
    }
  } else {
    for (int k = 0; k < blockDim.y; ++k) {
      int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
      warp_buf1[write_idx] = U(0);
      warp_buf2[write_idx] = U(0);
    }
  }
}

template <typename T, typename U, typename V>
__device__ void cuLoadAddStridedInputs(
    const int i1_block, const int thr_load_row_off, const int thr_load_col_off,
    const int i2_off, const int row_stride, U *warp_buf1, U *warp_buf2,
    const T *input, const V *dout, const int i1_end, const int n2,
    const U *__restrict__ mean, const U *__restrict__ invvar) {
  int i1 = i1_block + thr_load_row_off;
  if (i1 < i1_end) {
    U curr_mean = mean[i1];
    U curr_invvar = invvar[i1];
    for (int k = 0; k < blockDim.y; ++k) {
      int i2 = i2_off + k;
      int load_idx = i1 * n2 + i2;
      int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
      if (i2 < n2) {
        U curr_input = static_cast<U>(input[load_idx]);
        U curr_dout = static_cast<U>(dout[load_idx]);
        warp_buf1[write_idx] += curr_dout;
        warp_buf2[write_idx] +=
            curr_dout * (curr_input - curr_mean) * curr_invvar;
      }
    }
  }
}

template <typename T, typename U, typename V>
__global__ void cuComputePartGradGammaBeta(
    const V *__restrict__ dout, const T *__restrict__ input, const int n1,
    const int n2, const U *__restrict__ mean, const U *__restrict__ invvar,
    U epsilon, U *part_grad_gamma, U *part_grad_beta) {
  const int numsegs_n1 =
      (n1 + blockDim.y * blockDim.y - 1) / (blockDim.y * blockDim.y);
  const int segs_per_block = (numsegs_n1 + gridDim.y - 1) / gridDim.y;
  const int i1_beg = blockIdx.y * segs_per_block * blockDim.y * blockDim.y;
  const int i1_beg_plus_one =
      (blockIdx.y + 1) * segs_per_block * blockDim.y * blockDim.y;
  const int i1_end = i1_beg_plus_one < n1 ? i1_beg_plus_one : n1;
  const int row_stride = blockDim.x + 1;
  const int thr_load_col_off = (threadIdx.x * blockDim.y) & (blockDim.x - 1);
  const int thr_load_row_off =
      (threadIdx.x * blockDim.y) / blockDim.x + threadIdx.y * blockDim.y;
  const int i2_off = blockIdx.x * blockDim.x + thr_load_col_off;
  SharedMemory<U> shared;
  U *buf = shared.getPointer();  // buf has at least blockDim.x * blockDim.y *
                                 // blockDim.y + (blockDim.y -
                                 // 1)*(blockDim.x/blockDim.y) elements
  U *warp_buf1 = (U *)buf;
  U *warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride;
  // compute partial sums from strided inputs
  // do this to increase number of loads in flight
  cuLoadWriteStridedInputs(i1_beg, thr_load_row_off, thr_load_col_off, i2_off,
                           row_stride, warp_buf1, warp_buf2, input, dout,
                           i1_end, n2, mean, invvar);
  for (int i1_block = i1_beg + blockDim.y * blockDim.y; i1_block < i1_end;
       i1_block += blockDim.y * blockDim.y) {
    cuLoadAddStridedInputs(i1_block, thr_load_row_off, thr_load_col_off,
                           i2_off, row_stride, warp_buf1, warp_buf2, input,
                           dout, i1_end, n2, mean, invvar);
  }
  __syncthreads();
  // inter-warp reductions
  // sum within each warp
  U acc1 = U(0);
  U acc2 = U(0);
  for (int k = 0; k < blockDim.y; ++k) {
    int row1 = threadIdx.y + k * blockDim.y;
    int idx1 = row1 * row_stride + threadIdx.x;
    acc1 += warp_buf1[idx1];
    acc2 += warp_buf2[idx1];
  }
  warp_buf1[threadIdx.y * row_stride + threadIdx.x] = acc1;
  warp_buf2[threadIdx.y * row_stride + threadIdx.x] = acc2;
  __syncthreads();
  // sum all warps
  for (int offset = blockDim.y / 2; offset > 1; offset /= 2) {
    if (threadIdx.y < offset) {
      int row1 = threadIdx.y;
      int row2 = threadIdx.y + offset;
      int idx1 = row1 * row_stride + threadIdx.x;
      int idx2 = row2 * row_stride + threadIdx.x;
      warp_buf1[idx1] += warp_buf1[idx2];
      warp_buf2[idx1] += warp_buf2[idx2];
    }
    __syncthreads();
  }
  int i2 = blockIdx.x * blockDim.x + threadIdx.x;
  if (threadIdx.y == 0 && i2 < n2) {
    int row1 = threadIdx.y;
    int row2 = threadIdx.y + 1;
    int idx1 = row1 * row_stride + threadIdx.x;
    int idx2 = row2 * row_stride + threadIdx.x;
    part_grad_beta[blockIdx.y * n2 + i2] = warp_buf1[idx1] + warp_buf1[idx2];
    part_grad_gamma[blockIdx.y * n2 + i2] = warp_buf2[idx1] + warp_buf2[idx2];
  }
}

template <typename U, typename V>
__global__ void cuComputeGradGammaBeta(const U *part_grad_gamma,
                                       const U *part_grad_beta,
                                       const int part_size, const int n1,
                                       const int n2, V *grad_gamma,
                                       V *grad_beta) {
  // sum partial gradients for gamma and beta
  SharedMemory<U> shared;
  U *buf = shared.getPointer();
  int i2 = blockIdx.x * blockDim.x + threadIdx.x;
  if (i2 < n2) {
    // each warp does sequential reductions until reduced part_size is num_warps
    int num_warp_reductions = part_size / blockDim.y;
    U sum_gamma = U(0);
    U sum_beta = U(0);
    const U *part_grad_gamma_ptr =
        part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2;
    const U *part_grad_beta_ptr =
        part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2;
    for (int warp_offset = 0; warp_offset < num_warp_reductions;
         ++warp_offset) {
      sum_gamma += part_grad_gamma_ptr[warp_offset * n2];
      sum_beta += part_grad_beta_ptr[warp_offset * n2];
    }
    // inter-warp reductions
    const int nbsize3 = blockDim.x * blockDim.y / 2;
    for (int offset = blockDim.y / 2; offset >= 1; offset /= 2) {
      // top half write to shared memory
      if (threadIdx.y >= offset && threadIdx.y < 2 * offset) {
        const int write_idx =
            (threadIdx.y - offset) * blockDim.x + threadIdx.x;
        buf[write_idx] = sum_gamma;
        buf[write_idx + nbsize3] = sum_beta;
      }
      __syncthreads();
      // bottom half sums
      if (threadIdx.y < offset) {
        const int read_idx = threadIdx.y * blockDim.x + threadIdx.x;
        sum_gamma += buf[read_idx];
        sum_beta += buf[read_idx + nbsize3];
      }
      __syncthreads();
    }
    // write out fully summed gradients
    if (threadIdx.y == 0) {
      grad_gamma[i2] = sum_gamma;
      grad_beta[i2] = sum_beta;
    }
  }
}

template <typename T, typename U, typename V>
__global__ void cuComputeGradInput(const V *__restrict__ dout,
                                   const T *__restrict__ input, const int n1,
                                   const int n2, const U *__restrict__ mean,
                                   const U *__restrict__ invvar, U epsilon,
                                   const V *gamma, T *grad_input) {
  for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
    U sum_loss1 = U(0);
    U sum_loss2 = U(0);
    const U c_mean = mean[i1];
    const U c_invvar = invvar[i1];
    const T *k_input = input + i1 * n2;
    const V *k_dout = dout + i1 * n2;
    const int numx = blockDim.x * blockDim.y;
    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
    if (gamma != NULL) {
      int l = 4 * thrx;
      for (; l + 3 < n2; l += 4 * numx) {
        for (int k = 0; k < 4; ++k) {
          const U c_h = static_cast<U>(k_input[l + k]);
          const U c_loss = static_cast<U>(k_dout[l + k]);
          sum_loss1 += c_loss * gamma[l + k];
          sum_loss2 += c_loss * gamma[l + k] * (c_h - c_mean) * c_invvar;
        }
      }
      for (; l < n2; ++l) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        sum_loss1 += c_loss * gamma[l];
        sum_loss2 += c_loss * gamma[l] * (c_h - c_mean) * c_invvar;
      }
    } else {
      int l = 4 * thrx;
      for (; l + 3 < n2; l += 4 * numx) {
        for (int k = 0; k < 4; ++k) {
          const U c_h = static_cast<U>(k_input[l + k]);
          const U c_loss = static_cast<U>(k_dout[l + k]);
          sum_loss1 += c_loss;
          sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
        }
      }
      for (; l < n2; ++l) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        sum_loss1 += c_loss;
...
@@ -569,46 +519,46 @@ void cuComputeGradInput(
      }
    }
    // intra-warp reductions
    for (int mask = blockDim.x / 2; mask > 0; mask /= 2) {
      sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask);
      sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask);
    }
    // inter-warp reductions
    if (blockDim.y > 1) {
      SharedMemory<U> shared;
      U *buf = shared.getPointer();
      for (int offset = blockDim.y / 2; offset > 0; offset /= 2) {
        // upper half of warps write to shared
        if (threadIdx.y >= offset && threadIdx.y < 2 * offset) {
          const int wrt_i = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
          buf[2 * wrt_i] = sum_loss1;
          buf[2 * wrt_i + 1] = sum_loss2;
        }
        __syncthreads();
        // lower half merges
        if (threadIdx.y < offset) {
          const int read_i = threadIdx.y * blockDim.x + threadIdx.x;
          sum_loss1 += buf[2 * read_i];
          sum_loss2 += buf[2 * read_i + 1];
        }
        __syncthreads();
      }
      if (threadIdx.y == 0) {
        buf[2 * threadIdx.x] = sum_loss1;
        buf[2 * threadIdx.x + 1] = sum_loss2;
      }
      __syncthreads();
      if (threadIdx.y != 0) {
        sum_loss1 = buf[2 * threadIdx.x];
        sum_loss2 = buf[2 * threadIdx.x + 1];
      }
    }
    // all threads now have the two sums over l
    U fH = (U)n2;
    U term1 = (U(1) / fH) * c_invvar;
    T *k_grad_input = grad_input + i1 * n2;
    if (gamma != NULL) {
      for (int l = thrx; l < n2; l += numx) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        U f_grad_input = fH * c_loss * gamma[l];
...
@@ -618,7 +568,7 @@ void cuComputeGradInput(
        k_grad_input[l] = static_cast<T>(f_grad_input);
      }
    } else {
      for (int l = thrx; l < n2; l += numx) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        U f_grad_input = fH * c_loss;
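The reductions sum_loss1 and sum_loss2 feed the standard layer-norm input-gradient identity; the subtraction of the two sums and the final scaling by term1 = c_invvar / n2 sit in the lines elided between these hunks, so the formula below is a restatement of the math rather than a line-for-line transcription. Here $g_j = \gamma_j\,\partial L/\partial y_j$ when gamma is non-NULL, and just $\partial L/\partial y_j$ otherwise:

$$ \frac{\partial L}{\partial x_i} = \frac{1}{n_2\sqrt{\sigma^2+\epsilon}} \left( n_2\, g_i - \sum_j g_j - \frac{x_i-\mu}{\sqrt{\sigma^2+\epsilon}} \sum_j g_j\,\frac{x_j-\mu}{\sqrt{\sigma^2+\epsilon}} \right) $$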
...
@@ -631,183 +581,103 @@ void cuComputeGradInput(
  }
}

template <typename T, typename U, typename V>
void HostApplyLayerNorm(V *output, U *mean, U *invvar, const T *input, int n1,
                        int n2, double epsilon, const V *gamma,
                        const V *beta) {
  auto stream = at::cuda::getCurrentCUDAStream().stream();
  const dim3 threads(32, 4, 1);
  const uint64_t maxGridY =
      at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
  const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1);
  int nshared =
      threads.y > 1 ? threads.y * sizeof(U) + (threads.y / 2) * sizeof(U) : 0;
  cuApplyLayerNorm<<<blocks, threads, nshared, stream>>>(
      output, mean, invvar, input, n1, n2, U(epsilon), gamma, beta);
}

void cuda_layer_norm(at::Tensor *output, at::Tensor *mean, at::Tensor *invvar,
                     at::Tensor *input, int n1, int n2,
#ifdef VERSION_GE_1_1
                     at::IntArrayRef normalized_shape,
#else
                     at::IntList normalized_shape,
#endif
                     at::Tensor *gamma, at::Tensor *beta, double epsilon) {
  using namespace at;
  DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
      input->scalar_type(), output->scalar_type(), "cuda_layer_norm_kernel",
      HostApplyLayerNorm(output->DATA_PTR<scalar_t_out>(),
                         mean->DATA_PTR<float>(), invvar->DATA_PTR<float>(),
                         input->DATA_PTR<scalar_t_in>(), n1, n2, epsilon,
                         gamma != NULL ? gamma->DATA_PTR<scalar_t_out>() : NULL,
                         beta != NULL ? beta->DATA_PTR<scalar_t_out>() : NULL);)
}

template <typename T, typename U, typename V>
void HostLayerNormGradient(const V *dout, const U *mean, const U *invvar,
                           at::Tensor *input, int n1, int n2, const V *gamma,
                           const V *beta, double epsilon, T *grad_input,
                           V *grad_gamma, V *grad_beta) {
  auto stream = at::cuda::getCurrentCUDAStream().stream();

  if (gamma != NULL && beta != NULL) {
    // compute grad_gamma(j) and grad_beta(j)
    const int part_size = 16;
    const dim3 threads2(32, 4, 1);
    const dim3 blocks2((n2 + threads2.x - 1) / threads2.x, part_size, 1);
    const int nshared2_a =
        2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1);
    const int nshared2_b = threads2.x * threads2.y * sizeof(U);
    const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b;
    at::Tensor part_grad_gamma = at::empty(
        {part_size, n2}, input->options().dtype(at::ScalarType::Float));
    at::Tensor part_grad_beta = at::empty_like(part_grad_gamma);
    cuComputePartGradGammaBeta<<<blocks2, threads2, nshared2, stream>>>(
        dout, input->DATA_PTR<T>(), n1, n2, mean, invvar, U(epsilon),
        part_grad_gamma.DATA_PTR<U>(), part_grad_beta.DATA_PTR<U>());

    const dim3 threads3(32, 8, 1);
    const dim3 blocks3((n2 + threads2.x - 1) / threads2.x, 1, 1);
    const int nshared3 = threads3.x * threads3.y * sizeof(U);
    cuComputeGradGammaBeta<<<blocks3, threads3, nshared3, stream>>>(
        part_grad_gamma.DATA_PTR<U>(), part_grad_beta.DATA_PTR<U>(), part_size,
        n1, n2, grad_gamma, grad_beta);
  }

  // compute grad_input
  const uint64_t maxGridY =
      at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
  const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1);
  const dim3 threads1(32, 4, 1);
  int nshared = threads1.y > 1 ? threads1.y * threads1.x * sizeof(U) : 0;
  cuComputeGradInput<<<blocks1, threads1, nshared, stream>>>(
      dout, input->DATA_PTR<T>(), n1, n2, mean, invvar, U(epsilon), gamma,
      grad_input);
}

void cuda_layer_norm_gradient(at::Tensor *dout, at::Tensor *mean,
                              at::Tensor *invvar, at::Tensor *input, int n1,
                              int n2,
#ifdef VERSION_GE_1_1
                              at::IntArrayRef normalized_shape,
#else
                              at::IntList normalized_shape,
#endif
                              at::Tensor *gamma, at::Tensor *beta,
                              double epsilon, at::Tensor *grad_input,
                              at::Tensor *grad_gamma, at::Tensor *grad_beta) {
  using namespace at;
  DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
      input->scalar_type(), gamma->scalar_type(),
      "cuda_layer_norm_gradient_kernel",
      HostLayerNormGradient(
          dout->DATA_PTR<scalar_t_out>(), mean->DATA_PTR<float>(),
          invvar->DATA_PTR<float>(), input, n1, n2,
          // TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta
          // if gamma Tensor is NULL on input.
          gamma != NULL ? gamma->DATA_PTR<scalar_t_out>() : NULL,
          gamma != NULL ? beta->DATA_PTR<scalar_t_out>() : NULL, epsilon,
          grad_input->DATA_PTR<scalar_t_in>(),
          gamma != NULL ? grad_gamma->DATA_PTR<scalar_t_out>() : NULL,
          gamma != NULL ? grad_beta->DATA_PTR<scalar_t_out>() : NULL);)
}
\ No newline at end of file
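Outside the diff: when validating a kernel like cuApplyLayerNorm, a tiny CPU reference that mirrors the same math can be handy. The sketch below is an illustrative assumption, not part of this commit; the helper name layer_norm_ref and its signature are hypothetical.

#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical CPU reference for the forward pass: one row of length n2 per
// i1, y = gamma * (x - mean) / sqrt(var + eps) + beta, with biased variance
// as in the kernel (sigma2 / n2).
void layer_norm_ref(const std::vector<float> &x, int n1, int n2, float eps,
                    const std::vector<float> &gamma,
                    const std::vector<float> &beta, std::vector<float> &y) {
  for (int i1 = 0; i1 < n1; ++i1) {
    const float *row = x.data() + (std::size_t)i1 * n2;
    float mean = 0.f, var = 0.f;
    for (int i = 0; i < n2; ++i) mean += row[i];
    mean /= n2;
    for (int i = 0; i < n2; ++i) var += (row[i] - mean) * (row[i] - mean);
    var /= n2;
    const float invvar = 1.f / std::sqrt(var + eps);
    for (int i = 0; i < n2; ++i)
      y[(std::size_t)i1 * n2 + i] = gamma[i] * (row[i] - mean) * invvar + beta[i];
  }
}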