ModelZoo / LLama_fastertransformer · Commits

Commit 0211193c, authored Aug 17, 2023 by zhuwenwen

initial llama

Pipeline #509 failed in 0 seconds
260 files changed. Showing the first 20 files below, with 4399 additions and 0 deletions (+4399 −0).
3rdparty/Megatron-LM/megatron/fused_kernels/layer_norm_cuda_kernel.cu  +832 −0
3rdparty/Megatron-LM/megatron/fused_kernels/scaled_masked_softmax.cpp  +97 −0
3rdparty/Megatron-LM/megatron/fused_kernels/scaled_masked_softmax.h  +717 −0
3rdparty/Megatron-LM/megatron/fused_kernels/scaled_masked_softmax_cuda.cu  +117 −0
3rdparty/Megatron-LM/megatron/fused_kernels/scaled_softmax.cpp  +75 −0
3rdparty/Megatron-LM/megatron/fused_kernels/scaled_softmax_cuda.cu  +104 −0
3rdparty/Megatron-LM/megatron/fused_kernels/scaled_upper_triang_masked_softmax.cpp  +72 −0
3rdparty/Megatron-LM/megatron/fused_kernels/scaled_upper_triang_masked_softmax.h  +513 −0
3rdparty/Megatron-LM/megatron/fused_kernels/scaled_upper_triang_masked_softmax_cuda.cu  +98 −0
3rdparty/Megatron-LM/megatron/fused_kernels/tests/__init__.py  +0 −0
3rdparty/Megatron-LM/megatron/fused_kernels/tests/test_fused_kernels.py  +300 −0
3rdparty/Megatron-LM/megatron/fused_kernels/type_shim.h  +91 −0
3rdparty/Megatron-LM/megatron/global_vars.py  +275 −0
3rdparty/Megatron-LM/megatron/indexer.py  +129 −0
3rdparty/Megatron-LM/megatron/initialize.py  +253 −0
3rdparty/Megatron-LM/megatron/learning_rates.py  +164 −0
3rdparty/Megatron-LM/megatron/memory.py  +145 −0
3rdparty/Megatron-LM/megatron/microbatches.py  +157 −0
3rdparty/Megatron-LM/megatron/model/__init__.py  +24 −0
3rdparty/Megatron-LM/megatron/model/bert_model.py  +236 −0
3rdparty/Megatron-LM/megatron/fused_kernels/layer_norm_cuda_kernel.cu (new file, mode 100644)
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*This code is copied from NVIDIA apex:
* https://github.com/NVIDIA/apex
* with minor changes. */
#include "ATen/ATen.h"
#include "ATen/AccumulateType.h"
#include "ATen/cuda/CUDAContext.h"
#include "ATen/cuda/DeviceUtils.cuh"
#include <cuda.h>
#include <cuda_runtime.h>
#include "type_shim.h"
template<typename U> __device__
void cuWelfordOnlineSum(
  const U curr,
  U& mu,
  U& sigma2,
  U& count)
{
  count = count + U(1);
  U delta = curr - mu;
  U lmean = mu + delta / count;
  mu = lmean;
  U delta2 = curr - lmean;
  sigma2 = sigma2 + delta * delta2;
}
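For context (not part of this commit): cuWelfordOnlineSum is the classic single-pass Welford update, in which sigma2 accumulates the sum of squared deviations (M2) rather than the variance itself; dividing by the element count at the end yields the population variance. A minimal host-side C++ sketch of the same recurrence, with hypothetical names, useful for checking against the two-pass formula:

// Standalone host-side sketch of the Welford recurrence used above
// (illustrative only; not part of this commit).
#include <cstdio>
#include <vector>

int main() {
    std::vector<float> xs = {1.0f, 2.0f, 4.0f, 8.0f};
    float mu = 0.0f, m2 = 0.0f, count = 0.0f;
    for (float x : xs) {
        // identical update order to cuWelfordOnlineSum
        count += 1.0f;
        float delta = x - mu;
        mu += delta / count;
        m2 += delta * (x - mu);   // delta * delta2 in the kernel
    }
    // population variance = M2 / n, matching sigma2 / U(n2) in the kernel
    std::printf("mean=%f var=%f\n", mu, m2 / count);  // mean=3.75 var=7.1875
    return 0;
}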
template<typename U> __device__
void cuChanOnlineSum(
  const U muB,
  const U sigma2B,
  const U countB,
  U& mu,
  U& sigma2,
  U& count)
{
  U delta = muB - mu;
  U nA = count;
  U nB = countB;
  count = count + countB;
  U nX = count;
  if (nX > U(0)) {
    nA = nA / nX;
    nB = nB / nX;
    mu = nA*mu + nB*muB;
    sigma2 = sigma2 + sigma2B + delta * delta * nA * nB * nX;
  } else {
    mu = U(0);
    sigma2 = U(0);
  }
}
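cuChanOnlineSum merges two partial Welford aggregates (Chan et al.'s parallel update): the combined mean is the count-weighted average, and the combined M2 gains a correction of delta² · nA·nB / nX; in the kernel that correction appears as delta*delta*nA*nB*nX after nA and nB have been pre-divided by nX. A small host-side sketch (hypothetical names) verifying the merge against a direct computation:

// Host-side sketch of the Chan merge rule (illustrative only).
#include <cstdio>

struct Agg { double mu, m2, n; };

Agg merge(Agg a, Agg b) {
    double n = a.n + b.n;
    if (n <= 0.0) return {0.0, 0.0, 0.0};
    double delta = b.mu - a.mu;
    double mu = (a.n / n) * a.mu + (b.n / n) * b.mu;
    // correction term delta^2 * nA * nB / nX, factored as in the kernel
    double m2 = a.m2 + b.m2 + delta * delta * (a.n / n) * (b.n / n) * n;
    return {mu, m2, n};
}

int main() {
    // aggregates of {1,2} and {4,8}: (mu=1.5, m2=0.5) and (mu=6, m2=8)
    Agg left  = {1.5, 0.5, 2.0};
    Agg right = {6.0, 8.0, 2.0};
    Agg all = merge(left, right);
    // expected: mu=3.75, m2=28.75 (var = 28.75/4 = 7.1875), n=4
    std::printf("mu=%f m2=%f n=%f\n", all.mu, all.m2, all.n);
    return 0;
}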
template<typename T, typename U> __device__
void cuWelfordMuSigma2(
  const T* __restrict__ vals,
  const int n1,
  const int n2,
  const int i1,
  U& mu,
  U& sigma2,
  U* buf)
{
  // Assumptions:
  // 1) blockDim.x == warpSize
  // 2) Tensor is contiguous
  // 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available.
  //
  // compute variance and mean over n2
  U count = U(0);
  mu = U(0);
  sigma2 = U(0);
  if (i1 < n1) {
    // one warp normalizes one n1 index,
    // synchronization is implicit
    // initialize with standard Welford algorithm
    const int numx = blockDim.x * blockDim.y;
    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
    const T* lvals = vals + i1*n2;
    int l = 4*thrx;
    for (; l+3 < n2; l += 4*numx) {
      for (int k = 0; k < 4; ++k) {
        U curr = static_cast<U>(lvals[l+k]);
        cuWelfordOnlineSum<U>(curr, mu, sigma2, count);
      }
    }
    for (; l < n2; ++l) {
      U curr = static_cast<U>(lvals[l]);
      cuWelfordOnlineSum<U>(curr, mu, sigma2, count);
    }
    // intra-warp reductions
    for (int l = 0; l <= 4; ++l) {
      int srcLaneB = (threadIdx.x + (1<<l)) & 31;
      U muB = WARP_SHFL(mu, srcLaneB);
      U countB = WARP_SHFL(count, srcLaneB);
      U sigma2B = WARP_SHFL(sigma2, srcLaneB);
      cuChanOnlineSum<U>(muB, sigma2B, countB, mu, sigma2, count);
    }
    // threadIdx.x == 0 has correct values for each warp
    // inter-warp reductions
    if (blockDim.y > 1) {
      U* ubuf = (U*)buf;
      U* ibuf = (U*)(ubuf + blockDim.y);
      for (int offset = blockDim.y/2; offset > 0; offset /= 2) {
        // upper half of warps write to shared
        if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2*offset) {
          const int wrt_y = threadIdx.y - offset;
          ubuf[2*wrt_y] = mu;
          ubuf[2*wrt_y+1] = sigma2;
          ibuf[wrt_y] = count;
        }
        __syncthreads();
        // lower half merges
        if (threadIdx.x == 0 && threadIdx.y < offset) {
          U muB = ubuf[2*threadIdx.y];
          U sigma2B = ubuf[2*threadIdx.y+1];
          U countB = ibuf[threadIdx.y];
          cuChanOnlineSum<U>(muB, sigma2B, countB, mu, sigma2, count);
        }
        __syncthreads();
      }
      // threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values
      if (threadIdx.x == 0 && threadIdx.y == 0) {
        ubuf[0] = mu;
        ubuf[1] = sigma2;
      }
      __syncthreads();
      mu = ubuf[0];
      sigma2 = ubuf[1]/U(n2);
      // don't care about final value of count, we know count == n2
    } else {
      mu = WARP_SHFL(mu, 0);
      sigma2 = WARP_SHFL(sigma2/U(n2), 0);
    }
  }
}
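WARP_SHFL is a macro from type_shim.h (included above, not shown in this diff) that wraps __shfl_sync / __shfl depending on CUDA version. The intra-warp loop rotates data through lanes (threadIdx.x + 2^l) & 31 for l = 0..4, so after five merge steps every lane of the warp holds the statistics of all 32 lanes. A minimal standalone CUDA sketch of the same rotation pattern for a plain sum (hypothetical kernel, assuming CUDA >= 9):

// Rotation-based warp reduction: after 5 steps every lane holds the
// sum of all 32 lanes (same lane indexing as cuWelfordMuSigma2).
#include <cstdio>
#include <cuda_runtime.h>

__global__ void warp_rotate_sum(float* out) {
    float v = (float)threadIdx.x;          // each lane contributes its id
    for (int l = 0; l <= 4; ++l) {
        int srcLane = (threadIdx.x + (1 << l)) & 31;
        v += __shfl_sync(0xffffffff, v, srcLane);
    }
    out[threadIdx.x] = v;                  // every lane: 0+1+...+31 = 496
}

int main() {
    float* d; cudaMalloc(&d, 32 * sizeof(float));
    warp_rotate_sum<<<1, 32>>>(d);
    float h[32]; cudaMemcpy(h, d, sizeof(h), cudaMemcpyDeviceToHost);
    std::printf("lane0=%f lane31=%f\n", h[0], h[31]);  // both 496
    cudaFree(d);
    return 0;
}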
template<> __device__
void cuWelfordMuSigma2(
  const at::Half* __restrict__ vals,
  const int n1,
  const int n2,
  const int i1,
  float& mu,
  float& sigma2,
  float* buf)
{
  // Assumptions:
  // 1) blockDim.x == warpSize
  // 2) Tensor is contiguous
  // 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available.
  //
  // compute variance and mean over n2
  float count = 0.0f;
  mu = float(0);
  sigma2 = float(0);
  if (i1 < n1) {
    // one warp normalizes one n1 index,
    // synchronization is implicit
    // initialize with standard Welford algorithm
    const int numx = blockDim.x * blockDim.y;
    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
    const at::Half* lvals = vals + i1*n2;
    int l = 8*thrx;
    if ((((size_t)lvals)&3) != 0) {
      // 16 bit alignment
      // first thread consumes first point
      if (thrx == 0) {
        float curr = static_cast<float>(lvals[0]);
        cuWelfordOnlineSum(curr, mu, sigma2, count);
      }
      ++l;
    }
    // at this point, lvals[l] are 32 bit aligned for all threads.
    for (; l+7 < n2; l += 8*numx) {
      for (int k = 0; k < 8; k += 2) {
        float2 curr = __half22float2(*((__half2*)(lvals+l+k)));
        cuWelfordOnlineSum(curr.x, mu, sigma2, count);
        cuWelfordOnlineSum(curr.y, mu, sigma2, count);
      }
    }
    for (; l < n2; ++l) {
      float curr = static_cast<float>(lvals[l]);
      cuWelfordOnlineSum(curr, mu, sigma2, count);
    }
    // intra-warp reductions
    for (int l = 0; l <= 4; ++l) {
      int srcLaneB = (threadIdx.x + (1<<l)) & 31;
      float muB = WARP_SHFL(mu, srcLaneB);
      float countB = WARP_SHFL(count, srcLaneB);
      float sigma2B = WARP_SHFL(sigma2, srcLaneB);
      cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count);
    }
    // threadIdx.x == 0 has correct values for each warp
    // inter-warp reductions
    if (blockDim.y > 1) {
      float* ubuf = (float*)buf;
      float* ibuf = (float*)(ubuf + blockDim.y);
      for (int offset = blockDim.y/2; offset > 0; offset /= 2) {
        // upper half of warps write to shared
        if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2*offset) {
          const int wrt_y = threadIdx.y - offset;
          ubuf[2*wrt_y] = mu;
          ubuf[2*wrt_y+1] = sigma2;
          ibuf[wrt_y] = count;
        }
        __syncthreads();
        // lower half merges
        if (threadIdx.x == 0 && threadIdx.y < offset) {
          float muB = ubuf[2*threadIdx.y];
          float sigma2B = ubuf[2*threadIdx.y+1];
          float countB = ibuf[threadIdx.y];
          cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count);
        }
        __syncthreads();
      }
      // threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values
      if (threadIdx.x == 0 && threadIdx.y == 0) {
        ubuf[0] = mu;
        ubuf[1] = sigma2;
      }
      __syncthreads();
      mu = ubuf[0];
      sigma2 = ubuf[1]/float(n2);
      // don't care about final value of count, we know count == n2
    } else {
      mu = WARP_SHFL(mu, 0);
      sigma2 = WARP_SHFL(sigma2/float(n2), 0);
    }
  }
}
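The at::Half specialization widens throughput by reading __half2 pairs (four 32-bit loads covering 8 values per thread per trip) and converting with __half22float2; the leading scalar fix-up runs only when lvals is 2-byte but not 4-byte aligned. A standalone CUDA sketch of the vectorized conversion (hypothetical kernel, single thread for clarity):

// Sketch: consume 8 half values per trip via __half2 loads, as the
// at::Half specialization above does (illustrative only).
#include <cstdio>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

__global__ void half2_sum(const __half* in, float* out, int n) {
    float acc = 0.0f;
    for (int l = 0; l + 7 < n; l += 8) {
        for (int k = 0; k < 8; k += 2) {
            float2 curr = __half22float2(*((const __half2*)(in + l + k)));
            acc += curr.x + curr.y;
        }
    }
    *out = acc;
}

int main() {
    const int n = 16;
    __half h[n]; for (int i = 0; i < n; ++i) h[i] = __float2half((float)i);
    __half* din; float* dout;
    cudaMalloc(&din, n * sizeof(__half)); cudaMalloc(&dout, sizeof(float));
    cudaMemcpy(din, h, n * sizeof(__half), cudaMemcpyHostToDevice);
    half2_sum<<<1, 1>>>(din, dout, n);
    float s; cudaMemcpy(&s, dout, sizeof(float), cudaMemcpyDeviceToHost);
    std::printf("sum=%f\n", s);   // 0+1+...+15 = 120
    cudaFree(din); cudaFree(dout);
    return 0;
}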
template<typename U>
U rsqrt(U v) {
  return U(1) / sqrt(v);
}
template<>
float rsqrt(float v) {
  return rsqrtf(v);
}
template<>
double rsqrt(double v) {
  return rsqrt(v);
}
namespace {
// This is the un-specialized struct. Note that we prevent instantiation of this
// struct by putting an undefined symbol in the function body so it won't compile.
//  template <typename T>
//  struct SharedMemory
//  {
//      // Ensure that we won't compile any un-specialized types
//      __device__ T *getPointer()
//      {
//          extern __device__ void error(void);
//          error();
//          return NULL;
//      }
//  };
// https://github.com/NVIDIA/apex/issues/246
template <typename T>
struct SharedMemory;

template <>
struct SharedMemory <float>
{
    __device__ float *getPointer()
    {
        extern __shared__ float s_float[];
        return s_float;
    }
};
}
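The SharedMemory wrapper exists because extern __shared__ arrays of the same name must have the same type in every translation context; leaving the primary template undefined (per the linked apex issue) turns an unsupported element type into a compile error instead of silent aliasing. Only the float specialization is provided, which is all the kernels below instantiate. If another accumulation type were ever needed, the pattern extends by specialization only; a hypothetical double variant (not part of this commit) would look like:

// Hypothetical extension (not in this commit): a double specialization,
// with a differently named extern array to avoid type aliasing.
template <>
struct SharedMemory <double>
{
    __device__ double *getPointer()
    {
        extern __shared__ double s_double[];
        return s_double;
    }
};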
template<typename T, typename U, typename V> __global__
void cuApplyLayerNorm(
  V* __restrict__ output_vals,
  U* __restrict__ mean,
  U* __restrict__ invvar,
  const T* __restrict__ vals,
  const int n1,
  const int n2,
  const U epsilon,
  const V* __restrict__ gamma,
  const V* __restrict__ beta)
{
  // Assumptions:
  // 1) blockDim.x == warpSize
  // 2) Tensors are contiguous
  //
  for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
    SharedMemory<U> shared;
    U* buf = shared.getPointer();
    U mu, sigma2;
    cuWelfordMuSigma2(vals, n1, n2, i1, mu, sigma2, buf);
    const T* lvals = vals + i1*n2;
    V* ovals = output_vals + i1*n2;
    U c_invvar = rsqrt(sigma2 + epsilon);
    const int numx = blockDim.x * blockDim.y;
    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
    if (gamma != NULL && beta != NULL) {
      for (int i = thrx; i < n2; i += numx) {
        U curr = static_cast<U>(lvals[i]);
        ovals[i] = gamma[i] * static_cast<V>(c_invvar * (curr - mu)) + beta[i];
      }
    } else {
      for (int i = thrx; i < n2; i += numx) {
        U curr = static_cast<U>(lvals[i]);
        ovals[i] = static_cast<V>(c_invvar * (curr - mu));
      }
    }
    if (threadIdx.x == 0 && threadIdx.y == 0) {
      mean[i1] = mu;
      invvar[i1] = c_invvar;
    }
    __syncthreads();
  }
}
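cuApplyLayerNorm computes y = gamma * (x - mu) * invvar + beta per row, with invvar = 1/sqrt(sigma2 + eps), and also stashes mu and invvar for the backward pass. A plain C++ reference of the per-row math (hypothetical sketch, for checking a single row against the kernel):

// Reference single-row layer norm matching cuApplyLayerNorm's math
// (illustrative sketch, not part of this commit).
#include <cmath>
#include <cstdio>
#include <vector>

void layer_norm_row(const float* x, const float* gamma, const float* beta,
                    int n2, float eps, float* y, float* mean, float* invvar) {
    float mu = 0.0f;
    for (int i = 0; i < n2; ++i) mu += x[i];
    mu /= n2;
    float sigma2 = 0.0f;
    for (int i = 0; i < n2; ++i) sigma2 += (x[i] - mu) * (x[i] - mu);
    sigma2 /= n2;                       // population variance, as in the kernel
    float c_invvar = 1.0f / std::sqrt(sigma2 + eps);
    for (int i = 0; i < n2; ++i)
        y[i] = gamma[i] * (c_invvar * (x[i] - mu)) + beta[i];
    *mean = mu; *invvar = c_invvar;     // saved for the backward pass
}

int main() {
    std::vector<float> x = {1, 2, 4, 8}, g(4, 1.0f), b(4, 0.0f), y(4);
    float mu, iv;
    layer_norm_row(x.data(), g.data(), b.data(), 4, 1e-5f, y.data(), &mu, &iv);
    std::printf("mu=%f invvar=%f y0=%f\n", mu, iv, y[0]);
    return 0;
}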
template<typename T, typename U, typename V> __device__
void cuLoadWriteStridedInputs(
    const int i1_block,
    const int thr_load_row_off,
    const int thr_load_col_off,
    const int i2_off,
    const int row_stride,
    U* warp_buf1,
    U* warp_buf2,
    const T* input,
    const V* dout,
    const int i1_end,
    const int n2,
    const U* __restrict__ mean,
    const U* __restrict__ invvar)
{
  int i1 = i1_block + thr_load_row_off;
  if (i1 < i1_end) {
    U curr_mean = mean[i1];
    U curr_invvar = invvar[i1];
    for (int k = 0; k < blockDim.y; ++k) {
      int i2 = i2_off + k;
      int load_idx = i1*n2 + i2;
      int write_idx = thr_load_row_off*row_stride + thr_load_col_off + k;
      if (i2 < n2) {
        U curr_input = static_cast<U>(input[load_idx]);
        U curr_dout = static_cast<U>(dout[load_idx]);
        warp_buf1[write_idx] = curr_dout;
        warp_buf2[write_idx] = curr_dout * (curr_input - curr_mean) * curr_invvar;
      } else {
        warp_buf1[write_idx] = U(0);
        warp_buf2[write_idx] = U(0);
      }
    }
  } else {
    for (int k = 0; k < blockDim.y; ++k) {
      int write_idx = thr_load_row_off*row_stride + thr_load_col_off + k;
      warp_buf1[write_idx] = U(0);
      warp_buf2[write_idx] = U(0);
    }
  }
}
template<typename T, typename U, typename V> __device__
void cuLoadAddStridedInputs(
    const int i1_block,
    const int thr_load_row_off,
    const int thr_load_col_off,
    const int i2_off,
    const int row_stride,
    U* warp_buf1,
    U* warp_buf2,
    const T* input,
    const V* dout,
    const int i1_end,
    const int n2,
    const U* __restrict__ mean,
    const U* __restrict__ invvar)
{
  int i1 = i1_block + thr_load_row_off;
  if (i1 < i1_end) {
    U curr_mean = mean[i1];
    U curr_invvar = invvar[i1];
    for (int k = 0; k < blockDim.y; ++k) {
      int i2 = i2_off + k;
      int load_idx = i1*n2 + i2;
      int write_idx = thr_load_row_off*row_stride + thr_load_col_off + k;
      if (i2 < n2) {
        U curr_input = static_cast<U>(input[load_idx]);
        U curr_dout = static_cast<U>(dout[load_idx]);
        warp_buf1[write_idx] += curr_dout;
        warp_buf2[write_idx] += curr_dout * (curr_input - curr_mean) * curr_invvar;
      }
    }
  }
}
template<typename T, typename U, typename V> __global__
void cuComputePartGradGammaBeta(
    const V* __restrict__ dout,
    const T* __restrict__ input,
    const int n1,
    const int n2,
    const U* __restrict__ mean,
    const U* __restrict__ invvar,
    U epsilon,
    U* part_grad_gamma,
    U* part_grad_beta)
{
    const int numsegs_n1 = (n1 + blockDim.y*blockDim.y - 1) / (blockDim.y*blockDim.y);
    const int segs_per_block = (numsegs_n1 + gridDim.y - 1) / gridDim.y;
    const int i1_beg = blockIdx.y * segs_per_block * blockDim.y * blockDim.y;
    const int i1_beg_plus_one = (blockIdx.y + 1) * segs_per_block * blockDim.y * blockDim.y;
    const int i1_end = i1_beg_plus_one < n1 ? i1_beg_plus_one : n1;
    const int row_stride = blockDim.x + 1;
    const int thr_load_col_off = (threadIdx.x * blockDim.y) & (blockDim.x - 1);
    const int thr_load_row_off = (threadIdx.x * blockDim.y) / blockDim.x + threadIdx.y * blockDim.y;
    const int i2_off = blockIdx.x * blockDim.x + thr_load_col_off;
    SharedMemory<U> shared;
    U* buf = shared.getPointer(); // buf has at least blockDim.x * blockDim.y * blockDim.y + (blockDim.y - 1)*(blockDim.x/blockDim.y) elements
    U* warp_buf1 = (U*)buf;
    U* warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride;
    // compute partial sums from strided inputs
    // do this to increase number of loads in flight
    cuLoadWriteStridedInputs(i1_beg, thr_load_row_off, thr_load_col_off, i2_off, row_stride, warp_buf1, warp_buf2, input, dout, i1_end, n2, mean, invvar);
    for (int i1_block = i1_beg + blockDim.y*blockDim.y; i1_block < i1_end; i1_block += blockDim.y*blockDim.y) {
      cuLoadAddStridedInputs(i1_block, thr_load_row_off, thr_load_col_off, i2_off, row_stride, warp_buf1, warp_buf2, input, dout, i1_end, n2, mean, invvar);
    }
    __syncthreads();
    // inter-warp reductions
    // sum within each warp
    U acc1 = U(0);
    U acc2 = U(0);
    for (int k = 0; k < blockDim.y; ++k) {
      int row1 = threadIdx.y + k * blockDim.y;
      int idx1 = row1 * row_stride + threadIdx.x;
      acc1 += warp_buf1[idx1];
      acc2 += warp_buf2[idx1];
    }
    warp_buf1[threadIdx.y * row_stride + threadIdx.x] = acc1;
    warp_buf2[threadIdx.y * row_stride + threadIdx.x] = acc2;
    __syncthreads();
    // sum all warps
    for (int offset = blockDim.y/2; offset > 1; offset /= 2) {
      if (threadIdx.y < offset) {
        int row1 = threadIdx.y;
        int row2 = threadIdx.y + offset;
        int idx1 = row1 * row_stride + threadIdx.x;
        int idx2 = row2 * row_stride + threadIdx.x;
        warp_buf1[idx1] += warp_buf1[idx2];
        warp_buf2[idx1] += warp_buf2[idx2];
      }
      __syncthreads();
    }
    int i2 = blockIdx.x * blockDim.x + threadIdx.x;
    if (threadIdx.y == 0 && i2 < n2) {
      int row1 = threadIdx.y;
      int row2 = threadIdx.y + 1;
      int idx1 = row1 * row_stride + threadIdx.x;
      int idx2 = row2 * row_stride + threadIdx.x;
      part_grad_beta[blockIdx.y * n2 + i2] = warp_buf1[idx1] + warp_buf1[idx2];
      part_grad_gamma[blockIdx.y * n2 + i2] = warp_buf2[idx1] + warp_buf2[idx2];
    }
}
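cuComputePartGradGammaBeta accumulates, per column j, partial sums of dout (for grad_beta) and dout * xhat (for grad_gamma, with xhat = (x - mean) * invvar) over a slice of rows; cuComputeGradGammaBeta below then folds the part_size partial rows into the final gradients. The underlying math as a host reference (hypothetical sketch):

// Reference grad_gamma / grad_beta for layer norm (illustrative sketch):
// grad_gamma[j] = sum_i dout[i][j] * (input[i][j] - mean[i]) * invvar[i]
// grad_beta[j]  = sum_i dout[i][j]
#include <vector>

void grad_gamma_beta_ref(const std::vector<std::vector<float>>& dout,
                         const std::vector<std::vector<float>>& input,
                         const std::vector<float>& mean,
                         const std::vector<float>& invvar,
                         std::vector<float>& grad_gamma,
                         std::vector<float>& grad_beta) {
    size_t n1 = dout.size(), n2 = dout[0].size();
    grad_gamma.assign(n2, 0.0f);
    grad_beta.assign(n2, 0.0f);
    for (size_t i = 0; i < n1; ++i)
        for (size_t j = 0; j < n2; ++j) {
            grad_beta[j]  += dout[i][j];
            grad_gamma[j] += dout[i][j] * (input[i][j] - mean[i]) * invvar[i];
        }
}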
template<typename U, typename V> __global__
void cuComputeGradGammaBeta(
    const U* part_grad_gamma,
    const U* part_grad_beta,
    const int part_size,
    const int n1,
    const int n2,
    V* grad_gamma,
    V* grad_beta)
{
    // sum partial gradients for gamma and beta
    SharedMemory<U> shared;
    U* buf = shared.getPointer();
    int i2 = blockIdx.x * blockDim.x + threadIdx.x;
    if (i2 < n2) {
      // each warp does sequential reductions until reduced part_size is num_warps
      int num_warp_reductions = part_size / blockDim.y;
      U sum_gamma = U(0);
      U sum_beta = U(0);
      const U* part_grad_gamma_ptr = part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2;
      const U* part_grad_beta_ptr = part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2;
      for (int warp_offset = 0; warp_offset < num_warp_reductions; ++warp_offset) {
        sum_gamma += part_grad_gamma_ptr[warp_offset * n2];
        sum_beta += part_grad_beta_ptr[warp_offset * n2];
      }
      // inter-warp reductions
      const int nbsize3 = blockDim.x * blockDim.y / 2;
      for (int offset = blockDim.y/2; offset >= 1; offset /= 2) {
        // top half write to shared memory
        if (threadIdx.y >= offset && threadIdx.y < 2*offset) {
          const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
          buf[write_idx] = sum_gamma;
          buf[write_idx + nbsize3] = sum_beta;
        }
        __syncthreads();
        // bottom half sums
        if (threadIdx.y < offset) {
          const int read_idx = threadIdx.y * blockDim.x + threadIdx.x;
          sum_gamma += buf[read_idx];
          sum_beta += buf[read_idx + nbsize3];
        }
        __syncthreads();
      }
      // write out fully summed gradients
      if (threadIdx.y == 0) {
        grad_gamma[i2] = sum_gamma;
        grad_beta[i2] = sum_beta;
      }
    }
}
template<typename T, typename U, typename V> __global__
void cuComputeGradInput(
    const V* __restrict__ dout,
    const T* __restrict__ input,
    const int n1,
    const int n2,
    const U* __restrict__ mean,
    const U* __restrict__ invvar,
    U epsilon,
    const V* gamma,
    T* grad_input)
{
  for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
    U sum_loss1 = U(0);
    U sum_loss2 = U(0);
    const U c_mean = mean[i1];
    const U c_invvar = invvar[i1];
    const T* k_input = input + i1*n2;
    const V* k_dout = dout + i1*n2;
    const int numx = blockDim.x * blockDim.y;
    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
    if (gamma != NULL) {
      int l = 4*thrx;
      for (; l+3 < n2; l += 4*numx) {
        for (int k = 0; k < 4; ++k) {
          const U c_h = static_cast<U>(k_input[l+k]);
          const U c_loss = static_cast<U>(k_dout[l+k]);
          sum_loss1 += c_loss * gamma[l+k];
          sum_loss2 += c_loss * gamma[l+k] * (c_h - c_mean) * c_invvar;
        }
      }
      for (; l < n2; ++l) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        sum_loss1 += c_loss * gamma[l];
        sum_loss2 += c_loss * gamma[l] * (c_h - c_mean) * c_invvar;
      }
    } else {
      int l = 4*thrx;
      for (; l+3 < n2; l += 4*numx) {
        for (int k = 0; k < 4; ++k) {
          const U c_h = static_cast<U>(k_input[l+k]);
          const U c_loss = static_cast<U>(k_dout[l+k]);
          sum_loss1 += c_loss;
          sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
        }
      }
      for (; l < n2; ++l) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        sum_loss1 += c_loss;
        sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
      }
    }
    // intra-warp reductions
    for (int mask = blockDim.x/2; mask > 0; mask /= 2) {
      sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask);
      sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask);
    }
    // inter-warp reductions
    if (blockDim.y > 1) {
      SharedMemory<U> shared;
      U* buf = shared.getPointer();
      for (int offset = blockDim.y/2; offset > 0; offset /= 2) {
        // upper half of warps write to shared
        if (threadIdx.y >= offset && threadIdx.y < 2*offset) {
          const int wrt_i = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
          buf[2*wrt_i] = sum_loss1;
          buf[2*wrt_i+1] = sum_loss2;
        }
        __syncthreads();
        // lower half merges
        if (threadIdx.y < offset) {
          const int read_i = threadIdx.y * blockDim.x + threadIdx.x;
          sum_loss1 += buf[2*read_i];
          sum_loss2 += buf[2*read_i+1];
        }
        __syncthreads();
      }
      if (threadIdx.y == 0) {
        buf[2*threadIdx.x] = sum_loss1;
        buf[2*threadIdx.x+1] = sum_loss2;
      }
      __syncthreads();
      if (threadIdx.y != 0) {
        sum_loss1 = buf[2*threadIdx.x];
        sum_loss2 = buf[2*threadIdx.x+1];
      }
    }
    // all threads now have the two sums over l
    U fH = (U)n2;
    U term1 = (U(1) / fH) * c_invvar;
    T* k_grad_input = grad_input + i1*n2;
    if (gamma != NULL) {
      for (int l = thrx; l < n2; l += numx) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        U f_grad_input = fH * c_loss * gamma[l];
        f_grad_input -= sum_loss1;
        f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
        f_grad_input *= term1;
        k_grad_input[l] = static_cast<T>(f_grad_input);
      }
    } else {
      for (int l = thrx; l < n2; l += numx) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        U f_grad_input = fH * c_loss;
        f_grad_input -= sum_loss1;
        f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
        f_grad_input *= term1;
        k_grad_input[l] = static_cast<T>(f_grad_input);
      }
    }
    // prevent race where buf is written again before reads are done
    __syncthreads();
  }
}
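cuComputeGradInput implements the standard layer-norm input gradient: with H = n2 and xhat = (x - mu) * invvar, grad_x = (invvar / H) * (H * dy * gamma - sum_l(dy * gamma) - xhat * sum_l(dy * gamma * xhat)). A host reference of the same row-wise formula (hypothetical sketch):

// Reference layer-norm input gradient for one row, matching the
// fH / term1 arithmetic in cuComputeGradInput (illustrative sketch).
#include <vector>

void grad_input_row(const float* dout, const float* x, const float* gamma,
                    float mean, float invvar, int n2, float* grad_x) {
    float sum1 = 0.0f, sum2 = 0.0f;
    for (int l = 0; l < n2; ++l) {
        float g = gamma ? gamma[l] : 1.0f;
        float xhat = (x[l] - mean) * invvar;
        sum1 += dout[l] * g;          // sum_loss1 in the kernel
        sum2 += dout[l] * g * xhat;   // sum_loss2 in the kernel
    }
    float fH = (float)n2;
    float term1 = (1.0f / fH) * invvar;
    for (int l = 0; l < n2; ++l) {
        float g = gamma ? gamma[l] : 1.0f;
        float xhat = (x[l] - mean) * invvar;
        grad_x[l] = term1 * (fH * dout[l] * g - sum1 - xhat * sum2);
    }
}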
template<typename T, typename U, typename V>
void HostApplyLayerNorm(
    V* output,
    U* mean,
    U* invvar,
    const T* input,
    int n1,
    int n2,
    double epsilon,
    const V* gamma,
    const V* beta)
{
    auto stream = at::cuda::getCurrentCUDAStream().stream();
    const dim3 threads(32, 4, 1);
    const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
    const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1);
    int nshared =
        threads.y > 1 ?
            threads.y * sizeof(U) + (threads.y / 2) * sizeof(U) :
            0;
    cuApplyLayerNorm<<<blocks, threads, nshared, stream>>>(
        output, mean, invvar, input, n1, n2, U(epsilon), gamma, beta);
}
void cuda_layer_norm(
    at::Tensor* output,
    at::Tensor* mean,
    at::Tensor* invvar,
    at::Tensor* input,
    int n1,
    int n2,
    #ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
    #else
    at::IntList normalized_shape,
    #endif
    at::Tensor* gamma,
    at::Tensor* beta,
    double epsilon)
{
    using namespace at;
    DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
        input->scalar_type(), output->scalar_type(), "cuda_layer_norm_kernel",
        HostApplyLayerNorm(
            output->DATA_PTR<scalar_t_out>(),
            mean->DATA_PTR<float>(),
            invvar->DATA_PTR<float>(),
            input->DATA_PTR<scalar_t_in>(),
            n1, n2,
            epsilon,
            gamma != NULL ? gamma->DATA_PTR<scalar_t_out>() : NULL,
            beta != NULL ? beta->DATA_PTR<scalar_t_out>() : NULL);
    )
}
template<typename T, typename U, typename V>
void HostLayerNormGradient(
    const V* dout,
    const U* mean,
    const U* invvar,
    at::Tensor* input,
    int n1,
    int n2,
    const V* gamma,
    const V* beta,
    double epsilon,
    T* grad_input,
    V* grad_gamma,
    V* grad_beta)
{
    auto stream = at::cuda::getCurrentCUDAStream().stream();

    if (gamma != NULL && beta != NULL) {
      // compute grad_gamma(j) and grad_beta(j)
      const int part_size = 16;
      const dim3 threads2(32, 4, 1);
      const dim3 blocks2((n2 + threads2.x - 1) / threads2.x, part_size, 1);
      const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1);
      const int nshared2_b = threads2.x * threads2.y * sizeof(U);
      const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b;
      at::Tensor part_grad_gamma = at::empty({part_size, n2}, input->options().dtype(at::ScalarType::Float));
      at::Tensor part_grad_beta = at::empty_like(part_grad_gamma);
      cuComputePartGradGammaBeta<<<blocks2, threads2, nshared2, stream>>>(
          dout,
          input->DATA_PTR<T>(),
          n1, n2,
          mean,
          invvar,
          U(epsilon),
          part_grad_gamma.DATA_PTR<U>(),
          part_grad_beta.DATA_PTR<U>());

      const dim3 threads3(32, 8, 1);
      const dim3 blocks3((n2 + threads2.x - 1) / threads2.x, 1, 1);
      const int nshared3 = threads3.x * threads3.y * sizeof(U);
      cuComputeGradGammaBeta<<<blocks3, threads3, nshared3, stream>>>(
          part_grad_gamma.DATA_PTR<U>(),
          part_grad_beta.DATA_PTR<U>(),
          part_size,
          n1, n2,
          grad_gamma,
          grad_beta);
    }

    // compute grad_input
    const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
    const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1);
    const dim3 threads1(32, 4, 1);
    int nshared = threads1.y > 1 ? threads1.y * threads1.x * sizeof(U) : 0;
    cuComputeGradInput<<<blocks1, threads1, nshared, stream>>>(
        dout,
        input->DATA_PTR<T>(),
        n1, n2,
        mean,
        invvar,
        U(epsilon),
        gamma,
        grad_input);
}
void cuda_layer_norm_gradient(
    at::Tensor* dout,
    at::Tensor* mean,
    at::Tensor* invvar,
    at::Tensor* input,
    int n1,
    int n2,
    #ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
    #else
    at::IntList normalized_shape,
    #endif
    at::Tensor* gamma,
    at::Tensor* beta,
    double epsilon,
    at::Tensor* grad_input,
    at::Tensor* grad_gamma,
    at::Tensor* grad_beta)
{
    using namespace at;
    DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
        input->scalar_type(), gamma->scalar_type(), "cuda_layer_norm_gradient_kernel",
        HostLayerNormGradient(
            dout->DATA_PTR<scalar_t_out>(),
            mean->DATA_PTR<float>(),
            invvar->DATA_PTR<float>(),
            input,
            n1, n2,
            // TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta
            // if gamma Tensor is NULL on input.
            gamma != NULL ? gamma->DATA_PTR<scalar_t_out>() : NULL,
            gamma != NULL ? beta->DATA_PTR<scalar_t_out>() : NULL,
            epsilon,
            grad_input->DATA_PTR<scalar_t_in>(),
            gamma != NULL ? grad_gamma->DATA_PTR<scalar_t_out>() : NULL,
            gamma != NULL ? grad_beta->DATA_PTR<scalar_t_out>() : NULL);
    )
}
3rdparty/Megatron-LM/megatron/fused_kernels/scaled_masked_softmax.cpp (new file, mode 100644)
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cuda_fp16.h>
#include <torch/extension.h>
#include <vector>
namespace multihead_attn {
namespace fused_softmax {
namespace scaled_masked_softmax {

torch::Tensor fwd_cuda(
    torch::Tensor const& input,
    torch::Tensor const& mask,
    float scale_factor);

torch::Tensor bwd_cuda(
    torch::Tensor const& output_grads,
    torch::Tensor const& softmax_results,
    float scale_factor);

int get_batch_per_block_cuda(
    int query_seq_len,
    int key_seq_len,
    int batches,
    int attn_heads);

torch::Tensor fwd(
    torch::Tensor const& input,
    torch::Tensor const& mask,
    float scale_factor) {
  AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
  AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
             (input.scalar_type() == at::ScalarType::BFloat16),
             "Only fp16 and bf16 are supported");
  AT_ASSERTM(mask.dim() == 4, "expected 4D tensor");

  return fwd_cuda(input, mask, scale_factor);
}

torch::Tensor bwd(
    torch::Tensor const& output_grads,
    torch::Tensor const& softmax_results,
    float scale_factor) {
  AT_ASSERTM(output_grads.dim() == 4, "expected 4D tensor");
  AT_ASSERTM(softmax_results.dim() == 4, "expected 4D tensor");

  AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
             (output_grads.scalar_type() == at::ScalarType::BFloat16),
             "Only fp16 and bf16 are supported");
  AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
             (softmax_results.scalar_type() == at::ScalarType::BFloat16),
             "Only fp16 and bf16 are supported");

  return bwd_cuda(output_grads, softmax_results, scale_factor);
}

int get_batch_per_block(
    int query_seq_len,
    int key_seq_len,
    int batches,
    int attn_heads) {
  return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches, attn_heads);
}

} // end namespace scaled_masked_softmax
} // end namespace fused_softmax
} // end namespace multihead_attn

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward",
        &multihead_attn::fused_softmax::scaled_masked_softmax::fwd,
        "Self Multihead Attention scaled, time masked softmax -- Forward.");
  m.def("backward",
        &multihead_attn::fused_softmax::scaled_masked_softmax::bwd,
        "Self Multihead Attention scaled, time masked softmax -- Backward.");
  m.def("get_batch_per_block",
        &multihead_attn::fused_softmax::scaled_masked_softmax::get_batch_per_block,
        "Return Batch per block size.");
}
3rdparty/Megatron-LM/megatron/fused_kernels/scaled_masked_softmax.h (new file, mode 100644)
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <assert.h>
#include <cuda_fp16.h>
#include <cfloat>
#include <limits>
#include <stdint.h>
#include <c10/macros/Macros.h>
namespace {

template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);

template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 1>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; }

template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 4>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); }

template <>
__device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst, const c10::Half *src) { *dst = *src; }

template <>
__device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); }

template <>
__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) { *dst = *src; }

template <>
__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst, const uint8_t *src) { *((half2*) dst) = *((half2*) src); }

int log2_ceil(int value) {
    int log2_value = 0;
    while ((1 << log2_value) < value) ++log2_value;
    return log2_value;
}
template<typename T>
struct Add {
  __device__ __forceinline__ T operator()(T a, T b) const {
    return a + b;
  }
};

template<typename T>
struct Max {
  __device__ __forceinline__ T operator()(T a, T b) const {
    return a < b ? b : a;
  }
};

template <typename T>
__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff)
{
#if CUDA_VERSION >= 9000
    return __shfl_xor_sync(mask, value, laneMask, width);
#else
    return __shfl_xor(value, laneMask, width);
#endif
}
template <typename acc_t, int WARP_BATCH, int WARP_SIZE, template<typename> class ReduceOp>
__device__ __forceinline__ void warp_reduce(acc_t* sum) {
    ReduceOp<acc_t> r;
    #pragma unroll
    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
        #pragma unroll
        for (int i = 0; i < WARP_BATCH; ++i) {
            acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE);
            sum[i] = r(sum[i], b);
        }
    }
}
/*
 * Extended softmax (from native aten pytorch) with the following additional features
 * 1) input scaling
 */
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_softmax_warp_forward(
    output_t *dst,
    const input_t *src,
    const acc_t scale,
    int micro_batch_size,
    int element_count)
{
    // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
    // warp_size of method warp_softmax_forward_kernel.
    constexpr int next_power_of_two = 1 << log2_elements;
    constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
    constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
    constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
    constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;

    // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
    // gridDim/blockIdx = (seq_len, attn_heads, batches)
    int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z)) + threadIdx.y) * WARP_BATCH;

    // micro_batch_size might not be a multiple of WARP_BATCH. Check how
    // many batches have to computed within this WARP.
    int local_batches = micro_batch_size - first_batch;
    if (local_batches > WARP_BATCH)
        local_batches = WARP_BATCH;

    // there might be multiple batches per warp. compute the index within the batch
    int local_idx = threadIdx.x;

    src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
    dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;

    // load data from global memory
    acc_t elements[WARP_BATCH][WARP_ITERATIONS];
    input_t temp_data[ELEMENTS_PER_LDG_STG];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        int batch_element_count = (i >= local_batches) ? 0 : element_count;

        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;

            if (element_index < batch_element_count) {
                int itr_idx = i * element_count + it * WARP_SIZE;
                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + itr_idx);

                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    elements[i][it + element] = (acc_t)temp_data[element] * scale;
                }
            } else {
                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
                }
            }
        }
    }

    // compute max_value
    acc_t max_value[WARP_BATCH];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        max_value[i] = elements[i][0];
        #pragma unroll
        for (int it = 1; it < WARP_ITERATIONS; ++it) {
            max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
        }
    }
    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);

    acc_t sum[WARP_BATCH] { 0.0f };
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; ++it) {
            elements[i][it] = std::exp((elements[i][it] - max_value[i]));
            sum[i] += elements[i][it];
        }
    }
    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);

    // store result
    output_t out[ELEMENTS_PER_LDG_STG];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        if (i >= local_batches)
            break;
        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
            if (element_index < element_count) {
                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    out[element] = elements[i][it + element] / sum[i];
                }
                copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);
            } else {
                break;
            }
        }
    }
}
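scaled_softmax_warp_forward is the standard numerically stable softmax over the last dimension applied to scale * x: subtract the row max before exponentiating so exp never overflows, then normalize by the row sum. Out-of-range slots are primed with -infinity so they contribute exp(-inf) = 0. A host reference of the row-wise math (hypothetical sketch):

// Reference scaled softmax over one row (illustrative sketch):
// dst = softmax(scale * src), with the max-subtraction trick used above.
#include <algorithm>
#include <cmath>

void scaled_softmax_row(const float* src, float* dst, int n, float scale) {
    float max_value = scale * src[0];
    for (int i = 1; i < n; ++i) max_value = std::max(max_value, scale * src[i]);
    float sum = 0.0f;
    for (int i = 0; i < n; ++i) {
        dst[i] = std::exp(scale * src[i] - max_value);
        sum += dst[i];
    }
    for (int i = 0; i < n; ++i) dst[i] /= sum;
}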
/*
 * Extended softmax (from native aten pytorch) with the following additional features
 * 1) input scaling
 * 2) Explicit masking
 */
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_masked_softmax_warp_forward(
    output_t *dst,
    const input_t *src,
    const uint8_t *mask,
    const acc_t scale,
    int micro_batch_size,
    int element_count,
    int pad_batches)
{
    // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
    // warp_size of method warp_softmax_forward_kernel.
    constexpr int next_power_of_two = 1 << log2_elements;
    constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
    constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
    constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
    constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;

    // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
    // gridDim/blockIdx = (seq_len, attn_heads, batches)
    int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z)) + threadIdx.y) * WARP_BATCH;
    int pad_first_batch = 0;
    if (pad_batches != 1) {
        // bert style
        pad_first_batch = (blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) * WARP_BATCH;
    } else {
        // gpt2 style
        pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
    }

    // micro_batch_size might not be a multiple of WARP_BATCH. Check how
    // many batches have to computed within this WARP.
    int local_batches = micro_batch_size - first_batch;
    if (local_batches > WARP_BATCH)
        local_batches = WARP_BATCH;

    // there might be multiple batches per warp. compute the index within the batch
    int local_idx = threadIdx.x;

    src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
    dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
    mask += pad_first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;

    // load data from global memory
    acc_t elements[WARP_BATCH][WARP_ITERATIONS];
    input_t temp_data[ELEMENTS_PER_LDG_STG];
    uint8_t temp_mask[ELEMENTS_PER_LDG_STG];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        int batch_element_count = (i >= local_batches) ? 0 : element_count;

        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;

            if (element_index < batch_element_count) {
                int itr_idx = i * element_count + it * WARP_SIZE;
                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + itr_idx);
                copy_vector<uint8_t, ELEMENTS_PER_LDG_STG>(temp_mask, mask + itr_idx);

                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    if (temp_mask[element] != 1) {
                        elements[i][it + element] = (acc_t)temp_data[element] * scale;
                    } else {
                        elements[i][it + element] = -10000.0;
                    }
                }
            } else {
                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
                }
            }
        }
    }

    // compute max_value
    acc_t max_value[WARP_BATCH];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        max_value[i] = elements[i][0];
        #pragma unroll
        for (int it = 1; it < WARP_ITERATIONS; ++it) {
            max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
        }
    }
    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);

    acc_t sum[WARP_BATCH] { 0.0f };
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; ++it) {
            elements[i][it] = std::exp((elements[i][it] - max_value[i]));
            sum[i] += elements[i][it];
        }
    }
    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);

    // store result
    output_t out[ELEMENTS_PER_LDG_STG];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        if (i >= local_batches)
            break;
        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
            if (element_index < element_count) {
                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    out[element] = elements[i][it + element] / sum[i];
                }
                copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);
            } else {
                break;
            }
        }
    }
}
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_masked_softmax_warp_backward(
    output_t *gradInput,
    input_t *grad,
    const input_t *output,
    acc_t scale,
    int micro_batch_size,
    int element_count)
{
    // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
    // warp_size of method warp_softmax_backward_kernel.
    constexpr int next_power_of_two = 1 << log2_elements;
    constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
    constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
    constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
    constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;

    // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
    // gridDim/blockIdx = (seq_len, attn_heads, batches)
    int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;

    // micro_batch_size might not be a multiple of WARP_BATCH. Check how
    // many batches have to computed within this WARP.
    int local_batches = micro_batch_size - first_batch;
    if (local_batches > WARP_BATCH)
        local_batches = WARP_BATCH;

    // there might be multiple batches per warp. compute the index within the batch
    int local_idx = threadIdx.x;

    // the first element to process by the current thread
    int thread_offset = first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
    grad += thread_offset;
    output += thread_offset;
    gradInput += thread_offset;

    // load data from global memory
    acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
    acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
    input_t temp_grad[ELEMENTS_PER_LDG_STG];
    input_t temp_output[ELEMENTS_PER_LDG_STG];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        int batch_element_count = (i >= local_batches) ? 0 : element_count;

        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
            if (element_index < batch_element_count) {
                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_grad, grad + i * element_count + it * WARP_SIZE);
                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_output, output + i * element_count + it * WARP_SIZE);

                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    output_reg[i][it + element] = (acc_t)temp_output[element];
                }
                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element];
                }
            }
        }
    }

    acc_t sum[WARP_BATCH];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        sum[i] = grad_reg[i][0];
        #pragma unroll
        for (int it = 1; it < WARP_ITERATIONS; ++it) {
            sum[i] += grad_reg[i][it];
        }
    }
    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);

    // store result
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        if (i >= local_batches)
            break;
        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
            if (element_index < element_count) {
                // compute gradients
                output_t out[ELEMENTS_PER_LDG_STG];
                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i]));
                }
                copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + i * element_count + it * WARP_SIZE, out);
            }
        }
    }
}
}
// end of anonymous namespace
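The backward kernel uses the softmax Jacobian identity: with y = softmax(scale * x) and upstream grad dy, dx = scale * y * (dy - sum_k dy[k]*y[k]). In the kernel, grad_reg already holds dy * y elementwise, so sum[i] is the row dot product and the final write computes scale * (dy*y - y * sum). A host reference (hypothetical sketch):

// Reference backward for scaled softmax over one row (illustrative sketch):
// dx[j] = scale * (dy[j]*y[j] - y[j] * sum_k dy[k]*y[k]),
// mirroring grad_reg / output_reg / sum in the warp kernel.
void scaled_softmax_backward_row(const float* dy, const float* y,
                                 float* dx, int n, float scale) {
    float dot = 0.0f;
    for (int k = 0; k < n; ++k) dot += dy[k] * y[k];
    for (int j = 0; j < n; ++j)
        dx[j] = scale * (dy[j] * y[j] - y[j] * dot);
}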
int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int attn_heads){
    int log2_elements = log2_ceil(key_seq_len);
    const int next_power_of_two = 1 << log2_elements;

    int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
    int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

    constexpr int threads_per_block = 128;
    int warps_per_block = (threads_per_block / warp_size);
    int batches_per_block = warps_per_block * batches_per_warp;

    return batches_per_block;
}
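For example, key_seq_len = 1024 gives log2_elements = 10 and next_power_of_two = 1024, so warp_size = 32 (capped at C10_WARP_SIZE), batches_per_warp = 1 (since 1024 > 128), warps_per_block = 128/32 = 4, and the function returns 4. A host-only copy of the same arithmetic (hypothetical standalone sketch, assuming a warp size of 32):

// Standalone copy of the batch-per-block arithmetic (illustrative only).
#include <cstdio>

int batch_per_block_ref(int key_seq_len) {
    int log2_value = 0;
    while ((1 << log2_value) < key_seq_len) ++log2_value;   // log2_ceil
    int next_power_of_two = 1 << log2_value;
    int warp_size = next_power_of_two < 32 ? next_power_of_two : 32;
    int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
    int warps_per_block = 128 / warp_size;                  // 128 threads/block
    return warps_per_block * batches_per_warp;
}

int main() {
    std::printf("%d %d %d\n",
                batch_per_block_ref(1024),  // 4:  4 warps * 1 batch
                batch_per_block_ref(128),   // 8:  4 warps * 2 batches
                batch_per_block_ref(16));   // 16: 8 warps (warp_size 16) * 2
    return 0;
}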
template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_softmax_forward(
    output_t *dst,
    const input_t *src,
    const input_t scale,
    int query_seq_len,
    int key_seq_len,
    int batches,
    int attn_heads)
{
    TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 4096);
    if (key_seq_len == 0) {
        return;
    } else {
        int log2_elements = log2_ceil(key_seq_len);
        const int next_power_of_two = 1 << log2_elements;
        int batch_count = batches * attn_heads * query_seq_len;

        // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
        int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;

        // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward.
        int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

        // use 128 threads per block to maximize gpu utilization
        constexpr int threads_per_block = 128;

        int warps_per_block = (threads_per_block / warp_size);
        int batches_per_block = warps_per_block * batches_per_warp;
        TORCH_INTERNAL_ASSERT(query_seq_len % batches_per_block == 0);
        dim3 blocks(query_seq_len / batches_per_block, attn_heads, batches);
        dim3 threads(warp_size, warps_per_block, 1);
        // Launch code would be more elegant if C++ supported FOR CONSTEXPR
        switch (log2_elements) {
            case 0: // 1
                scaled_softmax_warp_forward<input_t, output_t, acc_t, 0>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
                break;
            case 1: // 2
                scaled_softmax_warp_forward<input_t, output_t, acc_t, 1>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
                break;
            case 2: // 4
                scaled_softmax_warp_forward<input_t, output_t, acc_t, 2>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
                break;
            case 3: // 8
                scaled_softmax_warp_forward<input_t, output_t, acc_t, 3>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
                break;
            case 4: // 16
                scaled_softmax_warp_forward<input_t, output_t, acc_t, 4>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
                break;
            case 5: // 32
                scaled_softmax_warp_forward<input_t, output_t, acc_t, 5>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
                break;
            case 6: // 64
                scaled_softmax_warp_forward<input_t, output_t, acc_t, 6>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
                break;
            case 7: // 128
                scaled_softmax_warp_forward<input_t, output_t, acc_t, 7>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
                break;
            case 8: // 256
                scaled_softmax_warp_forward<input_t, output_t, acc_t, 8>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
                break;
            case 9: // 512
                scaled_softmax_warp_forward<input_t, output_t, acc_t, 9>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
                break;
            case 10: // 1024
                scaled_softmax_warp_forward<input_t, output_t, acc_t, 10>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
                break;
            case 11: // 2048
                scaled_softmax_warp_forward<input_t, output_t, acc_t, 11>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
                break;
            case 12: // 4096
                scaled_softmax_warp_forward<input_t, output_t, acc_t, 12>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, key_seq_len);
                break;
            default:
                break;
        }
    }
}
template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_masked_softmax_forward(
    output_t *dst,
    const input_t *src,
    const uint8_t *mask,
    const input_t scale,
    int query_seq_len,
    int key_seq_len,
    int batches,
    int attn_heads,
    int pad_batches)
{
    TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 4096);
    if (key_seq_len == 0) {
        return;
    } else {
        int log2_elements = log2_ceil(key_seq_len);
        const int next_power_of_two = 1 << log2_elements;
        int batch_count = batches * attn_heads * query_seq_len;

        // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
        int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;

        // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward.
        int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

        // use 128 threads per block to maximize gpu utilization
        constexpr int threads_per_block = 128;

        int warps_per_block = (threads_per_block / warp_size);
        int batches_per_block = warps_per_block * batches_per_warp;
        TORCH_INTERNAL_ASSERT(query_seq_len % batches_per_block == 0);
        dim3 blocks(query_seq_len / batches_per_block, attn_heads, batches);
        dim3 threads(warp_size, warps_per_block, 1);
        // Launch code would be more elegant if C++ supported FOR CONSTEXPR
        switch (log2_elements) {
            case 0: // 1
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 0>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 1: // 2
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 1>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 2: // 4
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 2>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 3: // 8
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 3>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 4: // 16
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 4>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 5: // 32
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 5>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 6: // 64
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 6>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 7: // 128
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 7>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 8: // 256
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 8>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 9: // 512
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 9>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 10: // 1024
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 10>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 11: // 2048
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 11>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 12: // 4096
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 12>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            default:
                break;
        }
    }
}
template
<
typename
input_t
,
typename
output_t
,
typename
acc_t
>
void
dispatch_scaled_masked_softmax_backward
(
output_t
*
grad_input
,
input_t
*
grad
,
const
input_t
*
output
,
const
acc_t
scale
,
int
query_seq_len
,
int
key_seq_len
,
int
batches
,
int
attn_heads
)
{
TORCH_INTERNAL_ASSERT
(
key_seq_len
>=
0
&&
key_seq_len
<=
4096
);
if
(
key_seq_len
==
0
)
{
return
;
}
else
{
int
log2_elements
=
log2_ceil
(
key_seq_len
);
const
int
next_power_of_two
=
1
<<
log2_elements
;
int
batch_count
=
batches
*
attn_heads
*
query_seq_len
;
// This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.
int
warp_size
=
(
next_power_of_two
<
C10_WARP_SIZE
)
?
next_power_of_two
:
C10_WARP_SIZE
;
// This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward.
int
batches_per_warp
=
(
next_power_of_two
<=
128
)
?
2
:
1
;
// use 128 threads per block to maximimize gpu utilization
constexpr
int
threads_per_block
=
128
;
int
warps_per_block
=
(
threads_per_block
/
warp_size
);
int
batches_per_block
=
warps_per_block
*
batches_per_warp
;
int
blocks
=
batch_count
/
batches_per_block
;
dim3
threads
(
warp_size
,
warps_per_block
,
1
);
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch
(
log2_elements
)
{
case
0
:
// 1
scaled_masked_softmax_warp_backward
<
input_t
,
output_t
,
acc_t
,
0
>
<<<
blocks
,
threads
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
grad_input
,
grad
,
output
,
scale
,
batch_count
,
key_seq_len
);
break
;
case
1
:
// 2
scaled_masked_softmax_warp_backward
<
input_t
,
output_t
,
acc_t
,
1
>
<<<
blocks
,
threads
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
grad_input
,
grad
,
output
,
scale
,
batch_count
,
key_seq_len
);
break
;
case
2
:
// 4
scaled_masked_softmax_warp_backward
<
input_t
,
output_t
,
acc_t
,
2
>
<<<
blocks
,
threads
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
grad_input
,
grad
,
output
,
scale
,
batch_count
,
key_seq_len
);
break
;
case
3
:
// 8
scaled_masked_softmax_warp_backward
<
input_t
,
output_t
,
acc_t
,
3
>
<<<
blocks
,
threads
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
grad_input
,
grad
,
output
,
scale
,
batch_count
,
key_seq_len
);
break
;
case
4
:
// 16
scaled_masked_softmax_warp_backward
<
input_t
,
output_t
,
acc_t
,
4
>
<<<
blocks
,
threads
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
grad_input
,
grad
,
output
,
scale
,
batch_count
,
key_seq_len
);
break
;
case
5
:
// 32
scaled_masked_softmax_warp_backward
<
input_t
,
output_t
,
acc_t
,
5
>
<<<
blocks
,
threads
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
grad_input
,
grad
,
output
,
scale
,
batch_count
,
key_seq_len
);
break
;
case
6
:
// 64
scaled_masked_softmax_warp_backward
<
input_t
,
output_t
,
acc_t
,
6
>
<<<
blocks
,
threads
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
grad_input
,
grad
,
output
,
scale
,
batch_count
,
key_seq_len
);
break
;
case
7
:
// 128
scaled_masked_softmax_warp_backward
<
input_t
,
output_t
,
acc_t
,
7
>
<<<
blocks
,
threads
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
grad_input
,
grad
,
output
,
scale
,
batch_count
,
key_seq_len
);
break
;
case
8
:
// 256
scaled_masked_softmax_warp_backward
<
input_t
,
output_t
,
acc_t
,
8
>
<<<
blocks
,
threads
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
grad_input
,
grad
,
output
,
scale
,
batch_count
,
key_seq_len
);
break
;
case
9
:
// 512
scaled_masked_softmax_warp_backward
<
input_t
,
output_t
,
acc_t
,
9
>
<<<
blocks
,
threads
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
grad_input
,
grad
,
output
,
scale
,
batch_count
,
key_seq_len
);
break
;
case
10
:
// 1024
scaled_masked_softmax_warp_backward
<
input_t
,
output_t
,
acc_t
,
10
>
<<<
blocks
,
threads
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
grad_input
,
grad
,
output
,
scale
,
batch_count
,
key_seq_len
);
break
;
case
11
:
// 2048
scaled_masked_softmax_warp_backward
<
input_t
,
output_t
,
acc_t
,
11
>
<<<
blocks
,
threads
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
grad_input
,
grad
,
output
,
scale
,
batch_count
,
key_seq_len
);
break
;
case
12
:
// 4096
scaled_masked_softmax_warp_backward
<
input_t
,
output_t
,
acc_t
,
12
>
<<<
blocks
,
threads
,
0
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
grad_input
,
grad
,
output
,
scale
,
batch_count
,
key_seq_len
);
break
;
default:
break
;
}
}
}
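For reference, the launch geometry both dispatchers derive from key_seq_len can be mirrored in a few lines of Python (a sketch only, assuming C10_WARP_SIZE is 32 as on current NVIDIA GPUs):

# Sketch of the dispatcher's grid/block math (assumes C10_WARP_SIZE == 32).
def launch_geometry(batches, attn_heads, query_seq_len, key_seq_len):
    log2_elements = (key_seq_len - 1).bit_length()        # log2_ceil(key_seq_len)
    next_power_of_two = 1 << log2_elements
    batch_count = batches * attn_heads * query_seq_len
    warp_size = min(next_power_of_two, 32)                # WARP_SIZE in the kernel
    batches_per_warp = 2 if next_power_of_two <= 128 else 1   # WARP_BATCH
    warps_per_block = 128 // warp_size                    # 128 threads per block
    batches_per_block = warps_per_block * batches_per_warp
    blocks = batch_count // batches_per_block
    return blocks, (warp_size, warps_per_block, 1)        # grid size, dim3 threads

# e.g. key_seq_len=4096 -> warp_size=32, 1 batch per warp, 4 warps per block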
3rdparty/Megatron-LM/megatron/fused_kernels/scaled_masked_softmax_cuda.cu
0 → 100644
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "scaled_masked_softmax.h"
#include "type_shim.h"
namespace multihead_attn {
namespace fused_softmax {
namespace scaled_masked_softmax {

int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, int attn_heads){
    return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads);
}

torch::Tensor fwd_cuda(
    torch::Tensor const& input,
    torch::Tensor const& mask,
    float scale_factor)
{
  // input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
  const int batches = input.size(0);
  const int pad_batches = mask.size(0);
  const int attn_heads = input.size(1);
  const int query_seq_len = input.size(2);
  const int key_seq_len = input.size(3);
  TORCH_INTERNAL_ASSERT(key_seq_len <= 4096);
  TORCH_INTERNAL_ASSERT(query_seq_len > 1);
  TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches);
  TORCH_INTERNAL_ASSERT(mask.size(1) == 1);
  TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len);
  TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len);

  // Output
  auto act_options = input.options().requires_grad(false);
  torch::Tensor softmax_results =
      torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options);

  // Softmax Intermediate Result Ptr
  void* input_ptr = static_cast<void*>(input.data_ptr());
  void* mask_ptr = static_cast<void*>(mask.data_ptr());
  void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());

  DISPATCH_HALF_AND_BFLOAT(
      input.scalar_type(),
      "dispatch_scaled_masked_softmax_forward",
      dispatch_scaled_masked_softmax_forward<scalar_t, scalar_t, float>(
          reinterpret_cast<scalar_t*>(softmax_results_ptr),
          reinterpret_cast<const scalar_t*>(input_ptr),
          reinterpret_cast<const uint8_t*>(mask_ptr),
          scale_factor,
          query_seq_len,
          key_seq_len,
          batches,
          attn_heads,
          pad_batches);
  );
  return softmax_results;
}

torch::Tensor bwd_cuda(
    torch::Tensor const& output_grads_,
    torch::Tensor const& softmax_results_,
    float scale_factor)
{
  auto output_grads = output_grads_.contiguous();
  auto softmax_results = softmax_results_.contiguous();

  //output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
  const int batches = output_grads.size(0);
  const int attn_heads = output_grads.size(1);
  const int query_seq_len = output_grads.size(2);
  const int key_seq_len = output_grads.size(3);

  void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());

  //Softmax Grad
  DISPATCH_HALF_AND_BFLOAT(
      output_grads_.scalar_type(),
      "dispatch_scaled_masked_softmax_backward",
      dispatch_scaled_masked_softmax_backward<scalar_t, scalar_t, float>(
          reinterpret_cast<scalar_t*>(output_grads_ptr),
          reinterpret_cast<scalar_t*>(output_grads_ptr),
          reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
          scale_factor,
          query_seq_len,
          key_seq_len,
          batches,
          attn_heads);
  );

  //backward pass is completely in-place
  return output_grads;
}
}
}
}
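A minimal usage sketch of the extension above. It assumes the companion scaled_masked_softmax.cpp from this commit binds fwd_cuda/bwd_cuda as forward/backward, and that the module is built under the name scaled_masked_softmax_cuda (the name the tests in this commit import):

# Hedged usage sketch; the binding names are assumptions as noted above.
import torch
import scaled_masked_softmax_cuda

b, heads, sq, sk = 8, 16, 64, 64
scores = torch.randn(b, heads, sq, sk, device="cuda", dtype=torch.float16)
mask = torch.zeros(b, 1, sq, sk, device="cuda", dtype=torch.uint8)  # 1 = masked out
probs = scaled_masked_softmax_cuda.forward(scores, mask, 1.0)       # calls fwd_cuda
# backward overwrites (a contiguous copy of) its first argument and returns it
grads = scaled_masked_softmax_cuda.backward(torch.ones_like(probs), probs, 1.0)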
3rdparty/Megatron-LM/megatron/fused_kernels/scaled_softmax.cpp
0 → 100644
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cuda_fp16.h>
#include <torch/extension.h>
#include <vector>
namespace multihead_attn {
namespace fused_softmax {
namespace scaled_softmax {

torch::Tensor fwd_cuda(
    torch::Tensor const& input,
    float scale_factor);

torch::Tensor bwd_cuda(
    torch::Tensor const& output_grads,
    torch::Tensor const& softmax_results,
    float scale_factor);

torch::Tensor fwd(
    torch::Tensor const& input,
    float scale_factor) {
  AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
  AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
             (input.scalar_type() == at::ScalarType::BFloat16),
             "Only fp16 and bf16 are supported");

  return fwd_cuda(input, scale_factor);
}

torch::Tensor bwd(
    torch::Tensor const& output_grads,
    torch::Tensor const& softmax_results,
    float scale_factor) {
  AT_ASSERTM(output_grads.dim() == 4, "expected 4D tensor");
  AT_ASSERTM(softmax_results.dim() == 4, "expected 4D tensor");

  AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
             (output_grads.scalar_type() == at::ScalarType::BFloat16),
             "Only fp16 and bf16 are supported");
  AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
             (softmax_results.scalar_type() == at::ScalarType::BFloat16),
             "Only fp16 and bf16 are supported");

  return bwd_cuda(output_grads, softmax_results, scale_factor);
}

} // end namespace scaled_softmax
} // end namespace fused_softmax
} // end namespace multihead_attn

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward",
        &multihead_attn::fused_softmax::scaled_softmax::fwd,
        "Self Multihead Attention scaled, softmax -- Forward.");
  m.def("backward",
        &multihead_attn::fused_softmax::scaled_softmax::bwd,
        "Self Multihead Attention scaled, softmax -- Backward.");
}
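Once compiled, the bindings above can be exercised directly. A sketch, with shapes chosen to satisfy the asserts in fwd_cuda; the module name scaled_softmax_cuda is an assumption here, since TORCH_EXTENSION_NAME is supplied by the build, not by this file:

# Sketch only: module name assumed as described above.
import torch
import scaled_softmax_cuda

x = torch.randn(2, 8, 16, 16, device="cuda", dtype=torch.float16)
y = scaled_softmax_cuda.forward(x, 0.125)   # softmax(scale * x) over the last dim
# backward is in-place on (a contiguous copy of) the incoming gradients
dx = scaled_softmax_cuda.backward(torch.ones_like(y), y, 0.125)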
3rdparty/Megatron-LM/megatron/fused_kernels/scaled_softmax_cuda.cu
0 → 100644
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "scaled_masked_softmax.h"
#include "type_shim.h"
namespace multihead_attn {
namespace fused_softmax {
namespace scaled_softmax {

torch::Tensor fwd_cuda(
    torch::Tensor const& input,
    float scale_factor)
{
  // input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
  const int batches = input.size(0);
  const int attn_heads = input.size(1);
  const int query_seq_len = input.size(2);
  const int key_seq_len = input.size(3);
  TORCH_INTERNAL_ASSERT(key_seq_len <= 4096);
  TORCH_INTERNAL_ASSERT(query_seq_len > 1);

  // Output
  auto act_options = input.options().requires_grad(false);
  torch::Tensor softmax_results =
      torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options);

  // Softmax Intermediate Result Ptr
  void* input_ptr = static_cast<void*>(input.data_ptr());
  void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());

  DISPATCH_HALF_AND_BFLOAT(
      input.scalar_type(),
      "dispatch_scaled_softmax_forward",
      dispatch_scaled_softmax_forward<scalar_t, scalar_t, float>(
          reinterpret_cast<scalar_t*>(softmax_results_ptr),
          reinterpret_cast<const scalar_t*>(input_ptr),
          scale_factor,
          query_seq_len,
          key_seq_len,
          batches,
          attn_heads);
  );
  return softmax_results;
}

torch::Tensor bwd_cuda(
    torch::Tensor const& output_grads_,
    torch::Tensor const& softmax_results_,
    float scale_factor)
{
  auto output_grads = output_grads_.contiguous();
  auto softmax_results = softmax_results_.contiguous();

  //output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
  const int batches = output_grads.size(0);
  const int attn_heads = output_grads.size(1);
  const int query_seq_len = output_grads.size(2);
  const int key_seq_len = output_grads.size(3);

  void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());

  //Softmax Grad
  DISPATCH_HALF_AND_BFLOAT(
      output_grads_.scalar_type(),
      "dispatch_scaled_masked_softmax_backward",
      dispatch_scaled_masked_softmax_backward<scalar_t, scalar_t, float>(
          reinterpret_cast<scalar_t*>(output_grads_ptr),
          reinterpret_cast<scalar_t*>(output_grads_ptr),
          reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
          scale_factor,
          query_seq_len,
          key_seq_len,
          batches,
          attn_heads);
  );

  //backward pass is completely in-place
  return output_grads;
}
}
}
}
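What the in-place backward computes can be stated as an unfused reference, useful when validating the kernel (a sketch in plain torch, matching the warp kernel's math):

import torch

def scaled_softmax_bwd_reference(grad_output, softmax_out, scale):
    # matches the warp kernel: dx = scale * y * (g - sum_k(g_k * y_k))
    g = grad_output.float()
    y = softmax_out.float()
    inner = (g * y).sum(dim=-1, keepdim=True)
    return (scale * y * (g - inner)).to(softmax_out.dtype)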
3rdparty/Megatron-LM/megatron/fused_kernels/scaled_upper_triang_masked_softmax.cpp
0 → 100644
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cuda_fp16.h>
#include <torch/extension.h>
#include <vector>
namespace multihead_attn {
namespace fused_softmax {
namespace scaled_upper_triang_masked_softmax {

torch::Tensor fwd_cuda(
    torch::Tensor const& input,
    float scale_factor);

torch::Tensor bwd_cuda(
    torch::Tensor const& output_grads,
    torch::Tensor const& softmax_results,
    float scale_factor);

torch::Tensor fwd(
    torch::Tensor const& input,
    float scale_factor) {
  AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
  AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
             (input.scalar_type() == at::ScalarType::BFloat16),
             "Only fp16 and bf16 are supported");

  return fwd_cuda(input, scale_factor);
}

torch::Tensor bwd(
    torch::Tensor const& output_grads,
    torch::Tensor const& softmax_results,
    float scale_factor) {
  AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");

  AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
             (output_grads.scalar_type() == at::ScalarType::BFloat16),
             "Only fp16 and bf16 are supported");
  AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
             (softmax_results.scalar_type() == at::ScalarType::BFloat16),
             "Only fp16 and bf16 are supported");

  return bwd_cuda(output_grads, softmax_results, scale_factor);
}

} // end namespace scaled_upper_triang_masked_softmax
} // end namespace fused_softmax
} // end namespace multihead_attn

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward",
        &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd,
        "Self Multihead Attention scaled, time masked softmax -- Forward.");
  m.def("backward",
        &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd,
        "Self Multihead Attention scaled, time masked softmax -- Backward.");
}
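Usage of the causal bindings mirrors the padded variant, but on 3D input [attn_batches, seq_len, seq_len]; note the dispatcher additionally asserts attn_batches % batches_per_block == 0. A sketch (the module name scaled_upper_triang_masked_softmax_cuda matches what the tests in this commit import):

import torch
import scaled_upper_triang_masked_softmax_cuda

# seq_len=64 -> batches_per_block=8, so attn_batches must be a multiple of 8
scores = torch.randn(96, 64, 64, device="cuda", dtype=torch.float16)
probs = scaled_upper_triang_masked_softmax_cuda.forward(scores, 1.0)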
3rdparty/Megatron-LM/megatron/fused_kernels/scaled_upper_triang_masked_softmax.h
0 → 100644
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <assert.h>
#include <cuda_fp16.h>
#include <cfloat>
#include <limits>
#include <stdint.h>
#include <c10/macros/Macros.h>
namespace {

template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);

template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 1>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; }

template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 4>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); }

template <>
__device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst, const c10::Half *src) { *dst = *src; }

template <>
__device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); }

template <>
__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) { *dst = *src; }

template <>
__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst, const uint8_t *src) { *((half2*) dst) = *((half2*) src); }

template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_zero_vector(Datatype *dst);

template <>
__device__ __inline__ void copy_zero_vector<c10::BFloat16, 1>(c10::BFloat16 *dst) { *dst = 0.0; }

template <>
__device__ __inline__ void copy_zero_vector<c10::BFloat16, 4>(c10::BFloat16 *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); }

template <>
__device__ __inline__ void copy_zero_vector<c10::Half, 1>(c10::Half *dst) { *dst = 0.0; }

template <>
__device__ __inline__ void copy_zero_vector<c10::Half, 4>(c10::Half *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); }

int log2_ceil(int value) {
    int log2_value = 0;
    while ((1 << log2_value) < value) ++log2_value;
    return log2_value;
}

template<typename T>
struct Add {
    __device__ __forceinline__ T operator()(T a, T b) const {
        return a + b;
    }
};

template<typename T>
struct Max {
    __device__ __forceinline__ T operator()(T a, T b) const {
        return a < b ? b : a;
    }
};

template <typename T>
__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff)
{
#if CUDA_VERSION >= 9000
    return __shfl_xor_sync(mask, value, laneMask, width);
#else
    return __shfl_xor(value, laneMask, width);
#endif
}

template <typename acc_t, int WARP_BATCH, int WARP_SIZE, template<typename> class ReduceOp>
__device__ __forceinline__ void warp_reduce(acc_t* sum) {
    ReduceOp<acc_t> r;
    #pragma unroll
    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
        #pragma unroll
        for (int i = 0; i < WARP_BATCH; ++i) {
            acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE);
            sum[i] = r(sum[i], b);
        }
    }
}

/*
 * Extended softmax (from native aten pytorch) with following additional features
 * 1) input scaling
 * 2) Implicit time (diagonal masking)
 */
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_upper_triang_masked_softmax_warp_forward(
    output_t *dst,
    const input_t *src,
    const acc_t scale,
    int micro_batch_size,
    int stride,
    int element_count)
{
    // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
    // warp_size of method warp_softmax_forward_kernel.
    constexpr int next_power_of_two = 1 << log2_elements;
    constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
    constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
    constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
    constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;

    int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x;
    int local_seq = blockIdx.x + 1;
    int warp_iteration_limit = (local_seq + ELEMENTS_PER_LDG_STG * WARP_SIZE - 1) / WARP_SIZE;

    // micro_batch_size might not be a multiple of WARP_BATCH. Check how
    // many batches have to be computed within this WARP.
    int local_batches = micro_batch_size - first_batch;
    if (local_batches > WARP_BATCH)
        local_batches = WARP_BATCH;

    // there might be multiple batches per warp. compute the index within the batch
    int local_idx = threadIdx.x;

    src += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
    dst += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;

    // load data from global memory
    acc_t elements[WARP_BATCH][WARP_ITERATIONS];
    input_t temp_data[ELEMENTS_PER_LDG_STG];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        int batch_element_count = (i >= local_batches) ? 0 : local_seq;

        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;

            if (element_index < batch_element_count) {
                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + i * element_count * stride + it * WARP_SIZE);

                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    if ((element_index + element) < batch_element_count) {
                        elements[i][it + element] = (acc_t)temp_data[element] * scale;
                    } else {
                        elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
                    }
                }
            } else {
                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
                }
            }
        }
    }

    // compute max_value
    acc_t max_value[WARP_BATCH];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        max_value[i] = elements[i][0];
        #pragma unroll
        for (int it = 1; it < WARP_ITERATIONS; ++it) {
            max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
        }
    }
    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);

    acc_t sum[WARP_BATCH] { 0.0f };
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; ++it) {
            if (it < warp_iteration_limit) {
                elements[i][it] = std::exp((elements[i][it] - max_value[i]));
                sum[i] += elements[i][it];
            }
        }
    }
    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);

    // store result
    output_t out[ELEMENTS_PER_LDG_STG];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        if (i >= local_batches)
            break;
        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;

            if (element_index < local_seq) {
                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    if (element_index + element < local_seq) {
                        out[element] = elements[i][it + element] / sum[i];
                    } else {
                        out[element] = 0;
                    }
                }
                copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count * stride + it * WARP_SIZE, out);
            } else if (element_index < element_count) {
                copy_zero_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count * stride + it * WARP_SIZE);
            } else {
                break;
            }
        }
    }
}

template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_upper_triang_masked_softmax_warp_backward(
    output_t *gradInput,
    input_t *grad,
    const input_t *output,
    acc_t scale,
    int micro_batch_size,
    int stride,
    int element_count)
{
    // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
    // warp_size of method warp_softmax_backward_kernel.
    constexpr int next_power_of_two = 1 << log2_elements;
    constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
    constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
    constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
    constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;

    int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x;
    int local_seq = blockIdx.x + 1;

    // micro_batch_size might not be a multiple of WARP_BATCH. Check how
    // many batches have to be computed within this WARP.
    int local_batches = micro_batch_size - first_batch;
    if (local_batches > WARP_BATCH)
        local_batches = WARP_BATCH;

    // there might be multiple batches per warp. compute the index within the batch
    int local_idx = threadIdx.x;

    // the first element to process by the current thread
    int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
    grad += thread_offset;
    output += thread_offset;
    gradInput += thread_offset;

    // load data from global memory
    acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
    acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
    input_t temp_grad[ELEMENTS_PER_LDG_STG];
    input_t temp_output[ELEMENTS_PER_LDG_STG];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        int batch_element_count = (i >= local_batches) ? 0 : local_seq;

        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
            if (element_index < batch_element_count) {
                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_grad, grad + i * element_count * stride + it * WARP_SIZE);
                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_output, output + i * element_count * stride + it * WARP_SIZE);

                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    if (element_index + element < batch_element_count) {
                        output_reg[i][it + element] = (acc_t)temp_output[element];
                    }
                }
                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    if (element_index + element < batch_element_count) {
                        grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element];
                    }
                }
            }
        }
    }

    acc_t sum[WARP_BATCH];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        sum[i] = grad_reg[i][0];
        #pragma unroll
        for (int it = 1; it < WARP_ITERATIONS; ++it) {
            sum[i] += grad_reg[i][it];
        }
    }
    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);

    // store result
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        if (i >= local_batches)
            break;
        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
            if (element_index < element_count) {
                // compute gradients
                output_t out[ELEMENTS_PER_LDG_STG];
                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i]));
                }
                copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + i * element_count * stride + it * WARP_SIZE, out);
            }
        }
    }
}

} // end of anonymous namespace

template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_upper_triang_masked_softmax_forward(
    output_t *dst,
    const input_t *src,
    const input_t scale,
    int softmax_elements,
    int softmax_elements_stride,
    int attn_batches)
{
    TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048);
    if (softmax_elements == 0) {
        return;
    } else {
        int log2_elements = log2_ceil(softmax_elements);
        const int next_power_of_two = 1 << log2_elements;
        int seq_len = softmax_elements;
        int batch_count = attn_batches * seq_len;

        // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
        int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;

        // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward.
        int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

        // use 128 threads per block to maximize gpu utilization
        constexpr int threads_per_block = 128;

        int warps_per_block = (threads_per_block / warp_size);
        int batches_per_block = warps_per_block * batches_per_warp;
        TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);

        int blocks_per_seq = attn_batches / batches_per_block;
        dim3 blocks(seq_len, blocks_per_seq, 1);
        dim3 threads(warp_size, warps_per_block, 1);
        // Launch code would be more elegant if C++ supported FOR CONSTEXPR
        switch (log2_elements) {
            case 0: // 1
                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 0>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                        dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 1: // 2
                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 1>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                        dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 2: // 4
                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 2>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                        dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 3: // 8
                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 3>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                        dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 4: // 16
                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 4>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                        dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 5: // 32
                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 5>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                        dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 6: // 64
                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 6>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                        dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 7: // 128
                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 7>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                        dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 8: // 256
                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 8>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                        dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 9: // 512
                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 9>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                        dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 10: // 1024
                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 10>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                        dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 11: // 2048
                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 11>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                        dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            default:
                break;
        }
    }
}

template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_upper_triang_masked_softmax_backward(
    output_t *grad_input,
    input_t *grad,
    const input_t *output,
    const acc_t scale,
    int softmax_elements,
    int softmax_elements_stride,
    int attn_batches)
{
    TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048);
    if (softmax_elements == 0) {
        return;
    } else {
        int log2_elements = log2_ceil(softmax_elements);
        const int next_power_of_two = 1 << log2_elements;
        int seq_len = softmax_elements;
        int batch_count = attn_batches * seq_len;

        // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.
        int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;

        // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward.
        int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

        // use 128 threads per block to maximize gpu utilization
        constexpr int threads_per_block = 128;

        int warps_per_block = (threads_per_block / warp_size);
        int batches_per_block = warps_per_block * batches_per_warp;
        TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);

        int blocks_per_seq = attn_batches / batches_per_block;
        dim3 blocks(seq_len, blocks_per_seq, 1);
        dim3 threads(warp_size, warps_per_block, 1);
        // Launch code would be more elegant if C++ supported FOR CONSTEXPR
        switch (log2_elements) {
            case 0: // 1
                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 0>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                        grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 1: // 2
                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 1>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                        grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 2: // 4
                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 2>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                        grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 3: // 8
                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 3>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                        grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 4: // 16
                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 4>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                        grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 5: // 32
                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 5>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                        grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 6: // 64
                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 6>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                        grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 7: // 128
                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 7>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                        grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 8: // 256
                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 8>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                        grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 9: // 512
                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 9>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                        grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 10: // 1024
                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 10>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                        grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 11: // 2048
                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 11>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                        grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            default:
                break;
        }
    }
}
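An unfused reference for the forward kernel above (scale, apply the implicit causal mask, softmax over the last dimension), handy for checking numerics:

import torch

def upper_triang_masked_softmax_reference(x, scale):
    # x: [..., sq, sk]; positions above the diagonal are masked out
    sq, sk = x.shape[-2], x.shape[-1]
    keep = torch.tril(torch.ones(sq, sk, dtype=torch.bool, device=x.device))
    scores = (x.float() * scale).masked_fill(~keep, float("-inf"))
    return torch.softmax(scores, dim=-1).to(x.dtype)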
3rdparty/Megatron-LM/megatron/fused_kernels/scaled_upper_triang_masked_softmax_cuda.cu
0 → 100644
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "scaled_upper_triang_masked_softmax.h"
#include "type_shim.h"
namespace multihead_attn {
namespace fused_softmax {
namespace scaled_upper_triang_masked_softmax {

torch::Tensor fwd_cuda(
    torch::Tensor const& input,
    float scale_factor)
{
  // input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]
  const int attn_batches = input.size(0);
  const int seq_len = input.size(1);
  TORCH_INTERNAL_ASSERT(seq_len <= 2048);

  // Output
  auto act_options = input.options().requires_grad(false);
  torch::Tensor softmax_results =
      torch::empty({attn_batches, seq_len, seq_len}, act_options);

  // Softmax Intermediate Result Ptr
  void* input_ptr = static_cast<void*>(input.data_ptr());
  void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());

  DISPATCH_HALF_AND_BFLOAT(
      input.scalar_type(),
      "dispatch_scaled_upper_triang_masked_softmax_forward",
      dispatch_scaled_upper_triang_masked_softmax_forward<scalar_t, scalar_t, float>(
          reinterpret_cast<scalar_t*>(softmax_results_ptr),
          reinterpret_cast<const scalar_t*>(input_ptr),
          scale_factor,
          seq_len,
          seq_len,
          attn_batches);
  );
  return softmax_results;
}

torch::Tensor bwd_cuda(
    torch::Tensor const& output_grads_,
    torch::Tensor const& softmax_results_,
    float scale_factor)
{
  auto output_grads = output_grads_.contiguous();
  auto softmax_results = softmax_results_.contiguous();

  //output grads is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]
  const int attn_batches = output_grads.size(0);
  const int seq_len = output_grads.size(1);
  TORCH_INTERNAL_ASSERT(output_grads.size(1) == output_grads.size(2));

  void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());

  //Softmax Grad
  DISPATCH_HALF_AND_BFLOAT(
      output_grads_.scalar_type(),
      "dispatch_scaled_upper_triang_masked_softmax_backward",
      dispatch_scaled_upper_triang_masked_softmax_backward<scalar_t, scalar_t, float>(
          reinterpret_cast<scalar_t*>(output_grads_ptr),
          reinterpret_cast<scalar_t*>(output_grads_ptr),
          reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
          scale_factor,
          seq_len,
          seq_len,
          attn_batches);
  );

  //backward pass is completely in-place
  return output_grads;
}
}
}
}
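These fwd_cuda/bwd_cuda entry points are typically wrapped in a torch.autograd.Function so the in-place backward slots into autograd; a hedged sketch of such a wrapper (Megatron's fused_softmax.py follows this pattern):

import torch
import scaled_upper_triang_masked_softmax_cuda as _kernels

class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inputs, scale):
        softmax_results = _kernels.forward(inputs, scale)
        ctx.save_for_backward(softmax_results)
        ctx.scale = scale
        return softmax_results

    @staticmethod
    def backward(ctx, output_grads):
        (softmax_results,) = ctx.saved_tensors
        input_grads = _kernels.backward(output_grads, softmax_results, ctx.scale)
        return input_grads, None  # no gradient w.r.t. scale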
3rdparty/Megatron-LM/megatron/fused_kernels/tests/__init__.py
0 → 100644
3rdparty/Megatron-LM/megatron/fused_kernels/tests/test_fused_kernels.py
0 → 100644
import math

import torch
from torch.nn import LayerNorm

from megatron.model.enums import AttnMaskType
from megatron.model.fused_layer_norm import MixedFusedLayerNorm
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.utils import attention_mask_func


def test_load_fused_kernels():
    try:
        import fused_mix_prec_layer_norm_cuda
        import scaled_masked_softmax_cuda
        import scaled_upper_triang_masked_softmax_cuda
        import torch

        print("[Success] load_fused_kernels")
    except ImportError as e:
        print("[Fail] load_fused_kernels")
        raise e


def test_fused_softmax():
    bert = BertModel.from_pretrained("bert-base-cased").cuda().half()
    tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
    test_text = (
        "Hello. How are you? I am fine thank you and you? yes Good. "
        "hi hi hi hi hi hi hi hi hi hi hi hi hi"  # 32
    )

    tokens = tokenizer(
        [test_text] * 4,
        return_tensors="pt",
    )

    embedding_output = bert.embeddings(
        input_ids=tokens["input_ids"].cuda(),
        position_ids=None,
        token_type_ids=tokens["token_type_ids"].cuda(),
        inputs_embeds=None,
        past_key_values_length=0,
    )

    # (bsz, 1, 1, seq_len)
    mask = bert.get_extended_attention_mask(
        attention_mask=tokens["attention_mask"].cuda(),
        input_shape=tokens["input_ids"].shape,
        device=bert.device,
    )
    # (bsz, 1, seq_len, seq_len)
    mask = mask.repeat(1, 1, mask.size()[-1], 1)

    attention = bert.encoder.layer[0].attention.self
    key_layer = attention.transpose_for_scores(attention.key(embedding_output))
    query_layer = attention.transpose_for_scores(attention.query(embedding_output))

    attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
    attention_scores /= math.sqrt(key_layer.size()[-1])

    fused_softmax = (
        FusedScaleMaskSoftmax(
            input_in_fp16=True,
            input_in_bf16=False,
            mask_func=attention_mask_func,
            scale=None,
            softmax_in_fp32=False,
            attn_mask_type=AttnMaskType.padding,
            scaled_masked_softmax_fusion=True,
        )
        .cuda()
        .half()
    )

    fused_softmax_output = fused_softmax(
        attention_scores,
        (mask != 0),
    )

    torch_softmax = (
        FusedScaleMaskSoftmax(
            input_in_fp16=True,
            input_in_bf16=False,
            mask_func=attention_mask_func,
            scale=None,
            softmax_in_fp32=False,
            attn_mask_type=AttnMaskType.padding,
            scaled_masked_softmax_fusion=False,
        )
        .cuda()
        .half()
    )

    torch_softmax_output = torch_softmax(
        attention_scores,
        (mask != 0),
    )

    test_result = (fused_softmax_output - torch_softmax_output).abs()

    while test_result.dim() != 1:
        test_result = test_result.mean(dim=-1)

    diff = test_result.mean(dim=-1)

    if diff <= 1e-3:
        print(
            f"\n[Success] test_fused_softmax"
            f"\n > mean_difference={diff}"
            f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}"
            f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}"
        )
    else:
        print(
            f"\n[Fail] test_fused_softmax"
            f"\n > mean_difference={diff}, "
            f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}, "
            f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}"
        )


def test_fused_upper_triangle_mask_softmax():
    gpt = GPT2Model.from_pretrained("gpt2").cuda().half()
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    test_text = (
        "Hello. How are you? I am fine thank you and you? yes Good. "
        "hi hi hi hi hi hi hi"  # 24
    )

    tokens = tokenizer(
        [test_text] * 4,
        return_tensors="pt",
    )

    attention_mask = tokens["attention_mask"].cuda()
    attention_mask = attention_mask.view(attention_mask.size(0), -1)
    attention_mask = attention_mask[:, None, None, :]
    attention_mask = (1.0 - attention_mask) * -10000.0
    attention_mask = attention_mask.repeat(1, 1, attention_mask.size()[-1], 1)
    attn = gpt.h[0]

    hidden_states = gpt.wte(tokens["input_ids"].cuda())
    q, k, v = attn.attn.c_attn(hidden_states).split(768, dim=-1)
    q = attn.attn._split_heads(q, attn.attn.num_heads, attn.attn.head_dim)
    k = attn.attn._split_heads(k, attn.attn.num_heads, attn.attn.head_dim)
    attn_weights = torch.matmul(q, k.transpose(-1, -2))

    sq, sk = q.size(-2), k.size(-2)
    causal_mask = attn.attn.bias[:, :, sk - sq : sk, :sk].bool()
    total_mask = ~(causal_mask & (attention_mask == 0))
    """
    tensor([[[[False,  True,  True,  ...,  True,  True,  True],
              [False, False,  True,  ...,  True,  True,  True],
              [False, False, False,  ...,  True,  True,  True],
              ...,
              [False, False, False,  ..., False,  True,  True],
              [False, False, False,  ..., False, False,  True],
              [False, False, False,  ..., False, False, False]]]
    """

    fused_softmax = (
        FusedScaleMaskSoftmax(
            input_in_fp16=True,
            input_in_bf16=False,
            mask_func=attention_mask_func,
            scale=None,
            softmax_in_fp32=False,
            attn_mask_type=AttnMaskType.causal,
            scaled_masked_softmax_fusion=True,
        )
        .cuda()
        .half()
    )

    fused_softmax_output = fused_softmax(
        attn_weights,
        total_mask,
    )

    torch_softmax = (
        FusedScaleMaskSoftmax(
            input_in_fp16=True,
            input_in_bf16=False,
            mask_func=attention_mask_func,
            scale=None,
            softmax_in_fp32=False,
            attn_mask_type=AttnMaskType.causal,
            scaled_masked_softmax_fusion=False,
        )
        .cuda()
        .half()
    )

    torch_softmax_output = torch_softmax(
        attn_weights,
        total_mask,
    )

    test_result = (fused_softmax_output - torch_softmax_output).abs()

    while test_result.dim() != 1:
        test_result = test_result.mean(dim=-1)

    diff = test_result.mean(dim=-1)

    if diff <= 1e-3:
        print(
            f"\n[Success] test_fused_upper_triangle_mask_softmax"
            f"\n > mean_difference={diff}"
            f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}"
            f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}"
        )
    else:
        print(
            f"\n[Fail] test_fused_upper_triangle_mask_softmax"
            f"\n > mean_difference={diff}, "
            f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}, "
            f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}"
        )


def test_layer_norm():
    bert = BertModel.from_pretrained("bert-base-cased").cuda().half()
    tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
    test_text = (
        "Hello. How are you? I am fine thank you and you? yes Good. "
        "hi hi hi hi hi hi hi hi hi hi hi hi hi"  # 32
    )

    tokens = tokenizer(
        [test_text] * 4,
        return_tensors="pt",
    )

    # [bsz, seq_len, d_model]
    embedding_output = (
        bert.embeddings(
            input_ids=tokens["input_ids"].cuda(),
            position_ids=None,
            token_type_ids=tokens["token_type_ids"].cuda(),
            inputs_embeds=None,
            past_key_values_length=0,
        )
        .cuda()
        .half()
    )

    fused_layernorm_layer = (
        MixedFusedLayerNorm(normalized_shape=embedding_output.size(-1)).cuda().half()
    )

    torch_layernorm_layer = (
        LayerNorm(normalized_shape=embedding_output.size(-1)).cuda().half()
    )

    fused_output = fused_layernorm_layer(embedding_output)
    torch_output = torch_layernorm_layer(embedding_output)
    test_result = (fused_output - torch_output).abs()

    while test_result.dim() != 1:
        test_result = test_result.mean(dim=-1)

    diff = test_result.mean(dim=-1)

    if diff <= 1e-3:
        print(
            f"\n[Success] test_layer_norm"
            f"\n > mean_difference={diff}"
            f"\n > fused_values={fused_output[-1][-1][:5].tolist()}"
            f"\n > torch_values={torch_output[-1][-1][:5].tolist()}"
        )
    else:
        print(
            f"\n[Fail] test_layer_norm"
            f"\n > mean_difference={diff}, "
            f"\n > fused_values={fused_output[-1][-1][:5].tolist()}, "
            f"\n > torch_values={torch_output[-1][-1][:5].tolist()}"
        )


if __name__ == "__main__":
    try:
        from transformers import BertTokenizer, GPT2Tokenizer
        from transformers.models.bert.modeling_bert import BertModel
        from transformers.models.gpt2.modeling_gpt2 import GPT2Model
        import transformers

        transformers.logging.set_verbosity(
            transformers.logging.FATAL,
        )

    except:
        print("\n[Fail] Please install `transformers` package to test fused kernels\n")
        exit(-1)

    test_load_fused_kernels()
    test_fused_softmax()
    test_fused_upper_triangle_mask_softmax()
    test_layer_norm()
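The __main__ guard above lets the file be run directly once the extensions are built and `transformers` is installed. The import check can also be driven programmatically (a sketch, assuming the repo root is on PYTHONPATH):

from megatron.fused_kernels.tests import test_fused_kernels as t

t.test_load_fused_kernels()  # import check for the three built extensions
# The HF-based checks resolve BertModel/GPT2Model from module globals, so they
# are easiest to run via `python test_fused_kernels.py`, which injects them.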
3rdparty/Megatron-LM/megatron/fused_kernels/type_shim.h
0 → 100644
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <ATen/ATen.h>
#include "compat.h"
#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Half: \
{ \
using scalar_t = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
switch(TYPEIN) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_in = float; \
switch(TYPEOUT) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_out = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
} \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_in = at::Half; \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t_in = at::BFloat16; \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
}
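The macro stamps out the kernel body once per supported dtype at compile time. The same dispatch expressed in Python terms (an analogy only, not part of the extension):

import torch

def dispatch_half_and_bfloat(dtype, name, handlers):
    # handlers: {torch.float16: fn, torch.bfloat16: fn}, mirroring the macro's cases
    try:
        return handlers[dtype]
    except KeyError:
        raise TypeError(f"{name} not implemented for '{dtype}'")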
3rdparty/Megatron-LM/megatron/global_vars.py
0 → 100644
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Megatron global variables."""
import
os
import
sys
import
time
import
torch
from
megatron
import
dist_signal_handler
from
megatron.tokenizer
import
build_tokenizer
from
.arguments
import
parse_args
from
.microbatches
import
build_num_microbatches_calculator
_GLOBAL_ARGS
=
None
_GLOBAL_NUM_MICROBATCHES_CALCULATOR
=
None
_GLOBAL_TOKENIZER
=
None
_GLOBAL_TENSORBOARD_WRITER
=
None
_GLOBAL_ADLR_AUTORESUME
=
None
_GLOBAL_TIMERS
=
None
_GLOBAL_SIGNAL_HANDLER
=
None
def
get_args
():
"""Return arguments."""
_ensure_var_is_initialized
(
_GLOBAL_ARGS
,
'args'
)
return
_GLOBAL_ARGS
def
get_num_microbatches
():
return
_GLOBAL_NUM_MICROBATCHES_CALCULATOR
.
get
()
def
get_current_global_batch_size
():
return
_GLOBAL_NUM_MICROBATCHES_CALCULATOR
.
get_current_global_batch_size
()
def
update_num_microbatches
(
consumed_samples
,
consistency_check
=
True
):
_GLOBAL_NUM_MICROBATCHES_CALCULATOR
.
update
(
consumed_samples
,
consistency_check
)
def
get_tokenizer
():
"""Return tokenizer."""
_ensure_var_is_initialized
(
_GLOBAL_TOKENIZER
,
'tokenizer'
)
return
_GLOBAL_TOKENIZER
def
get_tensorboard_writer
():
"""Return tensorboard writer. It can be None so no need
to check if it is initialized."""
return
_GLOBAL_TENSORBOARD_WRITER
def
get_adlr_autoresume
():
"""ADLR autoresume object. It can be None so no need
to check if it is initialized."""
return
_GLOBAL_ADLR_AUTORESUME
def
get_timers
():
"""Return timers."""
_ensure_var_is_initialized
(
_GLOBAL_TIMERS
,
'timers'
)
return
_GLOBAL_TIMERS
def
get_signal_handler
():
_ensure_var_is_initialized
(
_GLOBAL_SIGNAL_HANDLER
,
'signal handler'
)
return
_GLOBAL_SIGNAL_HANDLER
def
_set_signal_handler
():
global
_GLOBAL_SIGNAL_HANDLER
_ensure_var_is_not_initialized
(
_GLOBAL_SIGNAL_HANDLER
,
'signal handler'
)
_GLOBAL_SIGNAL_HANDLER
=
dist_signal_handler
.
DistributedSignalHandler
().
__enter__
()
def
set_global_variables
(
extra_args_provider
=
None
,
args_defaults
=
{},
ignore_unknown_args
=
False
):
"""Set args, tokenizer, tensorboard-writer, adlr-autoresume, and timers."""
args
=
_parse_args
(
extra_args_provider
=
extra_args_provider
,
defaults
=
args_defaults
,
ignore_unknown_args
=
ignore_unknown_args
)
_build_num_microbatches_calculator
(
args
)
if
args
.
vocab_file
:
_
=
_build_tokenizer
(
args
)
_set_tensorboard_writer
(
args
)
_set_adlr_autoresume
(
args
)
_set_timers
()
if
args
.
exit_signal_handler
:
_set_signal_handler
()
def
_parse_args
(
extra_args_provider
=
None
,
defaults
=
{},
ignore_unknown_args
=
False
):
"""Parse entire arguments."""
global
_GLOBAL_ARGS
_ensure_var_is_not_initialized
(
_GLOBAL_ARGS
,
'args'
)
_GLOBAL_ARGS
=
parse_args
(
extra_args_provider
=
extra_args_provider
,
defaults
=
defaults
,
ignore_unknown_args
=
ignore_unknown_args
)
return
_GLOBAL_ARGS
def
_build_num_microbatches_calculator
(
args
):
global
_GLOBAL_NUM_MICROBATCHES_CALCULATOR
_ensure_var_is_not_initialized
(
_GLOBAL_NUM_MICROBATCHES_CALCULATOR
,
'num microbatches calculator'
)
_GLOBAL_NUM_MICROBATCHES_CALCULATOR
=
build_num_microbatches_calculator
(
args
)
def
_build_tokenizer
(
args
):
"""Initialize tokenizer."""
global
_GLOBAL_TOKENIZER
_ensure_var_is_not_initialized
(
_GLOBAL_TOKENIZER
,
'tokenizer'
)
_GLOBAL_TOKENIZER
=
build_tokenizer
(
args
)
return
_GLOBAL_TOKENIZER
def
rebuild_tokenizer
(
args
):
global
_GLOBAL_TOKENIZER
_GLOBAL_TOKENIZER
=
None
return
_build_tokenizer
(
args
)
def
_set_tensorboard_writer
(
args
):
"""Set tensorboard writer."""
global
_GLOBAL_TENSORBOARD_WRITER
_ensure_var_is_not_initialized
(
_GLOBAL_TENSORBOARD_WRITER
,
'tensorboard writer'
)
if
hasattr
(
args
,
'tensorboard_dir'
)
and
\
args
.
tensorboard_dir
and
args
.
rank
==
(
args
.
world_size
-
1
):
try
:
from
torch.utils.tensorboard
import
SummaryWriter
print
(
'> setting tensorboard ...'
)
_GLOBAL_TENSORBOARD_WRITER
=
SummaryWriter
(
log_dir
=
args
.
tensorboard_dir
,
max_queue
=
args
.
tensorboard_queue_size
)
except
ModuleNotFoundError
:
print
(
'WARNING: TensorBoard writing requested but is not '
'available (are you using PyTorch 1.1.0 or later?), '
'no TensorBoard logs will be written.'
,
flush
=
True
)
def
_set_adlr_autoresume
(
args
):
"""Initialize ADLR autoresume."""
global
_GLOBAL_ADLR_AUTORESUME
_ensure_var_is_not_initialized
(
_GLOBAL_ADLR_AUTORESUME
,
'adlr autoresume'
)
if
args
.
adlr_autoresume
:
if
args
.
rank
==
0
:
print
(
'enabling autoresume ...'
,
flush
=
True
)
sys
.
path
.
append
(
os
.
environ
.
get
(
'SUBMIT_SCRIPTS'
,
'.'
))
try
:
from
userlib.auto_resume
import
AutoResume
except
BaseException
:
print
(
'ADLR autoresume is not available, exiting ...'
)
sys
.
exit
()
_GLOBAL_ADLR_AUTORESUME
=
AutoResume
def
_set_timers
():
"""Initialize timers."""
global
_GLOBAL_TIMERS
_ensure_var_is_not_initialized
(
_GLOBAL_TIMERS
,
'timers'
)
_GLOBAL_TIMERS
=
Timers
()
def
_ensure_var_is_initialized
(
var
,
name
):
"""Make sure the input variable is not None."""
assert
var
is
not
None
,
'{} is not initialized.'
.
format
(
name
)
def
_ensure_var_is_not_initialized
(
var
,
name
):
"""Make sure the input variable is not None."""
assert
var
is
None
,
'{} is already initialized.'
.
format
(
name
)
class _Timer:
    """Timer."""

    def __init__(self, name):
        self.name_ = name
        self.elapsed_ = 0.0
        self.started_ = False
        self.start_time = time.time()

    def start(self):
        """Start the timer."""
        assert not self.started_, 'timer has already been started'
        torch.cuda.synchronize()
        self.start_time = time.time()
        self.started_ = True

    def stop(self):
        """Stop the timer."""
        assert self.started_, 'timer is not started'
        torch.cuda.synchronize()
        self.elapsed_ += (time.time() - self.start_time)
        self.started_ = False

    def reset(self):
        """Reset timer."""
        self.elapsed_ = 0.0
        self.started_ = False

    def elapsed(self, reset=True):
        """Calculate the elapsed time."""
        started_ = self.started_
        # If timing is in progress, end it first.
        if self.started_:
            self.stop()
        # Get the elapsed time.
        elapsed_ = self.elapsed_
        # Reset the elapsed time.
        if reset:
            self.reset()
        # If timing was in progress, restart it.
        if started_:
            self.start()
        return elapsed_
class Timers:
    """Group of timers."""

    def __init__(self):
        self.timers = {}

    def __call__(self, name):
        if name not in self.timers:
            self.timers[name] = _Timer(name)
        return self.timers[name]

    def write(self, names, writer, iteration, normalizer=1.0, reset=False):
        """Write timers to a tensorboard writer."""
        # Currently when using add_scalars,
        # torch.utils.add_scalars makes each timer its own run, which
        # pollutes the runs list, so we just add each as a scalar.
        assert normalizer > 0.0
        for name in names:
            value = self.timers[name].elapsed(reset=reset) / normalizer
            writer.add_scalar(name + '-time', value, iteration)

    def log(self, names, normalizer=1.0, reset=True):
        """Log a group of timers."""
        assert normalizer > 0.0
        string = 'time (ms)'
        for name in names:
            elapsed_time = self.timers[name].elapsed(
                reset=reset) * 1000.0 / normalizer
            string += ' | {}: {:.2f}'.format(name, elapsed_time)
        if torch.distributed.is_initialized():
            if torch.distributed.get_rank() == (
                    torch.distributed.get_world_size() - 1):
                print(string, flush=True)
        else:
            print(string, flush=True)
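
A minimal usage sketch of the timer API above (hypothetical snippet, not part of the file; the run_* calls are placeholders for real work, and the torch.cuda.synchronize() calls inside _Timer assume a CUDA device):

# Hypothetical usage sketch; run_forward_pass/run_backward_pass are placeholders.
timers = Timers()
timers('forward').start()
run_forward_pass()
timers('forward').stop()
timers('backward').start()
run_backward_pass()
timers('backward').stop()
# On the last rank this prints e.g. "time (ms) | forward: 12.34 | backward: 23.45".
timers.log(['forward', 'backward'], normalizer=1.0)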
3rdparty/Megatron-LM/megatron/indexer.py
0 → 100644
View file @ 0211193c
import sys
import time

import torch
import torch.distributed as dist

from megatron import get_args, print_rank_0
from megatron import mpu
from megatron.checkpointing import load_biencoder_checkpoint
from megatron.data.orqa_wiki_dataset import get_open_retrieval_wiki_dataset
from megatron.data.orqa_wiki_dataset import get_open_retrieval_batch
from megatron.data.biencoder_dataset_utils import get_one_epoch_dataloader
from megatron.data.realm_index import detach, OpenRetreivalDataStore
from megatron.model.biencoder_model import get_model_provider
from megatron.training import get_model


class IndexBuilder(object):
    """
    Object for taking one pass over a dataset and creating a BlockData of its
    embeddings.
    """
    def __init__(self):
        args = get_args()
        self.model = None
        self.dataloader = None
        self.evidence_embedder_obj = None
        self.biencoder_shared_query_context_model = \
            args.biencoder_shared_query_context_model

        # need to know whether we're using a REALM checkpoint (args.load)
        # or ICT checkpoint
        assert not (args.load and args.ict_load)

        self.log_interval = args.indexer_log_interval
        self.batch_size = args.indexer_batch_size

        self.load_attributes()
        self.is_main_builder = mpu.get_data_parallel_rank() == 0
        self.num_total_builders = mpu.get_data_parallel_world_size()
        self.iteration = self.total_processed = 0
    def load_attributes(self):
        """
        Load the necessary attributes: model, dataloader and empty BlockData.
        """
        only_context_model = True
        if self.biencoder_shared_query_context_model:
            only_context_model = False

        model = get_model(get_model_provider(
            only_context_model=only_context_model,
            biencoder_shared_query_context_model=
                self.biencoder_shared_query_context_model))

        self.model = load_biencoder_checkpoint(
            model, only_context_model=only_context_model)

        assert len(self.model) == 1
        self.model[0].eval()

        self.dataset = get_open_retrieval_wiki_dataset()
        self.dataloader = iter(get_one_epoch_dataloader(self.dataset,
                                                        self.batch_size))

        self.evidence_embedder_obj = OpenRetreivalDataStore(
            load_from_path=False)

    def track_and_report_progress(self, batch_size):
        """
        Utility function for tracking progress.
        """
        self.iteration += 1
        self.total_processed += batch_size * self.num_total_builders
        if self.is_main_builder and self.iteration % self.log_interval == 0:
            print('Batch {:10d} | Total {:10d}'.format(
                self.iteration, self.total_processed), flush=True)

    def build_and_save_index(self):
        """
        Goes through one epoch of the dataloader and adds all data to this
        instance's BlockData.

        The copy of BlockData is saved as a shard, which when run in a
        distributed setting will be consolidated by the rank 0 process
        and saved as a final pickled BlockData.
        """
        assert len(self.model) == 1
        unwrapped_model = self.model[0]
        while not hasattr(unwrapped_model, 'embed_text'):
            unwrapped_model = unwrapped_model.module

        while True:
            try:
                # batch also has query_tokens and query_pad_data
                row_id, context_tokens, context_mask, context_types, \
                    context_pad_mask = get_open_retrieval_batch(
                        self.dataloader)
            except (StopIteration, IndexError):
                break

            # TODO: can we add with torch.no_grad() to reduce memory usage
            # detach, separate fields and add to BlockData
            assert context_mask.dtype == torch.bool
            context_logits = unwrapped_model.embed_text(
                unwrapped_model.context_model, context_tokens, context_mask,
                context_types)
            context_logits = detach(context_logits)
            row_id = detach(row_id)

            self.evidence_embedder_obj.add_block_data(row_id, context_logits)
            self.track_and_report_progress(batch_size=len(row_id))

        # This process signals to finalize its shard and then synchronize with
        # the other processes
        self.evidence_embedder_obj.save_shard()
        torch.distributed.barrier()
        del self.model

        # rank 0 process builds the final copy
        if self.is_main_builder:
            self.evidence_embedder_obj.merge_shards_and_save()
            # make sure that every single piece of data was embedded
            assert len(self.evidence_embedder_obj.embed_data) == \
                len(self.dataset)
        self.evidence_embedder_obj.clear()

        # complete building the final copy
        torch.distributed.barrier()
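
For orientation, a hedged sketch of how this builder is typically driven; the surrounding script is assumed to exist and the exact entry point below is an invention, not part of this commit:

# Hypothetical driver; assumes Megatron args/distributed state are set up first.
from megatron.initialize import initialize_megatron
from megatron.indexer import IndexBuilder

initialize_megatron(extra_args_provider=None)
index_builder = IndexBuilder()        # loads model, dataloader, empty BlockData
index_builder.build_and_save_index()  # one pass -> per-rank shard -> merged index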
3rdparty/Megatron-LM/megatron/initialize.py
0 → 100644
View file @ 0211193c
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Megatron initialization."""

import random
import os
import time

import numpy as np
import torch
from datetime import timedelta

from megatron import fused_kernels
from megatron import get_adlr_autoresume
from megatron import get_args
from megatron import get_tensorboard_writer
from megatron import mpu
from megatron.global_vars import set_global_variables
from megatron.mpu import (set_tensor_model_parallel_rank,
                          set_tensor_model_parallel_world_size)


def initialize_megatron(extra_args_provider=None, args_defaults={},
                        ignore_unknown_args=False, allow_no_cuda=False):
    """Set global variables, initialize distributed, and
    set autoresume and random seeds.

    `allow_no_cuda` should not be set unless using megatron for cpu only
    data processing. In general this arg should not be set unless you know
    what you are doing.

    Returns a function to finalize distributed env initialization
    (optionally, only when args.lazy_mpu_init == True)
    """
    if not allow_no_cuda:
        # Make sure cuda is available.
        assert torch.cuda.is_available(), 'Megatron requires CUDA.'

    # Parse args, build tokenizer, and set adlr-autoresume,
    # tensorboard-writer, and timers.
    set_global_variables(extra_args_provider=extra_args_provider,
                         args_defaults=args_defaults,
                         ignore_unknown_args=ignore_unknown_args)

    # torch.distributed initialization
    def finish_mpu_init():
        args = get_args()
        # Pytorch distributed.
        _initialize_distributed()

        # Random seeds for reproducibility.
        if args.rank == 0:
            print('> setting random seeds to {} ...'.format(args.seed))
        _set_random_seed(args.seed, args.data_parallel_random_init)

    # Set pytorch JIT layer fusion options.
    _set_jit_fusion_options()

    args = get_args()
    if args.lazy_mpu_init:
        args.use_cpu_initialization = True
        # delayed initialization of DDP-related stuff
        # We only set basic DDP globals
        set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
        # and return function for external DDP manager
        # to call when it has DDP initialized
        set_tensor_model_parallel_rank(args.rank)
        return finish_mpu_init
    else:
        # Megatron's MPU is the master. Complete initialization right away.
        finish_mpu_init()

        # Autoresume.
        _init_autoresume()

        # Compile dependencies.
        _compile_dependencies()

        # No continuation function
        return None
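
A hedged caller sketch for the function above (the defaults dict is invented; a real script would pass its own):

# Hypothetical caller; 'GPT2BPETokenizer' is just an illustrative default.
from megatron import get_args
from megatron.initialize import initialize_megatron

initialize_megatron(args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})
args = get_args()
if args.rank == 0:
    print('initialized with world size {}'.format(args.world_size))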
def _compile_dependencies():

    args = get_args()

    # =========================
    # Compile dataset C++ code.
    # =========================
    # TODO: move this to ninja
    if torch.distributed.get_rank() == 0:
        start_time = time.time()
        print('> compiling dataset index builder ...')
        from megatron.data.dataset_utils import compile_helper
        compile_helper()
        print('>>> done with dataset index builder. Compilation time: {:.3f} '
              'seconds'.format(time.time() - start_time), flush=True)

    # ==================
    # Load fused kernels
    # ==================

    # Custom kernel constraints check.
    seq_len = args.seq_length
    attn_batch_size = \
        (args.num_attention_heads / args.tensor_model_parallel_size) * \
        args.micro_batch_size
    # Constraints on sequence length and attn_batch_size to enable warp based
    # optimization and upper triangular optimization (for causal mask)
    custom_kernel_constraint = seq_len > 16 and seq_len <= 4096 and \
        seq_len % 4 == 0 and attn_batch_size % 4 == 0
    # Print a warning.
    if not ((args.fp16 or args.bf16) and
            custom_kernel_constraint and
            args.masked_softmax_fusion):
        if args.rank == 0:
            print('WARNING: constraints for invoking optimized'
                  ' fused softmax kernel are not met. We default'
                  ' back to unfused kernel invocations.', flush=True)

    # Always build on rank zero first.
    if torch.distributed.get_rank() == 0:
        start_time = time.time()
        print('> compiling and loading fused kernels ...', flush=True)
        fused_kernels.load(args)
        torch.distributed.barrier()
    else:
        torch.distributed.barrier()
        fused_kernels.load(args)
    # Simple barrier to make sure all ranks have passed the
    # compilation phase successfully before moving on to the
    # rest of the program. We think this might ensure that
    # the lock is released.
    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>>> done with compiling and loading fused kernels. '
              'Compilation time: {:.3f} seconds'.format(
                  time.time() - start_time), flush=True)
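
The fused-kernel build above follows a generic "rank 0 builds first" pattern; a stripped-down sketch of just that pattern (build is a placeholder callable, and an initialized process group is assumed):

import torch

def build_on_rank_zero_first(build):
    # Rank 0 compiles while the other ranks wait at the barrier.
    if torch.distributed.get_rank() == 0:
        build()
        torch.distributed.barrier()
    else:
        # Wait for rank 0 to warm the build cache, then load from it.
        torch.distributed.barrier()
        build()
    # All ranks synced before continuing.
    torch.distributed.barrier()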
def _initialize_distributed():
    """Initialize torch.distributed and mpu."""
    args = get_args()

    device_count = torch.cuda.device_count()
    if torch.distributed.is_initialized():

        if args.rank == 0:
            print('torch distributed is already initialized, '
                  'skipping initialization ...', flush=True)
        args.rank = torch.distributed.get_rank()
        args.world_size = torch.distributed.get_world_size()

    else:

        if args.rank == 0:
            print('> initializing torch distributed ...', flush=True)
        # Manually set the device ids.
        if device_count > 0:
            device = args.rank % device_count
            if args.local_rank is not None:
                assert args.local_rank == device, \
                    'expected local-rank to be the same as rank % device-count.'
            else:
                args.local_rank = device
            torch.cuda.set_device(device)
        # Call the init process
        torch.distributed.init_process_group(
            backend=args.distributed_backend,
            world_size=args.world_size, rank=args.rank,
            timeout=timedelta(minutes=10))

    # Set the tensor model-parallel, pipeline model-parallel, and
    # data-parallel communicators.
    if device_count > 0:
        if mpu.model_parallel_is_initialized():
            print('model parallel is already initialized')
        else:
            mpu.initialize_model_parallel(
                args.tensor_model_parallel_size,
                args.pipeline_model_parallel_size,
                args.virtual_pipeline_model_parallel_size,
                args.pipeline_model_parallel_split_rank)


def _init_autoresume():
    """Set autoresume start time."""
    autoresume = get_adlr_autoresume()
    if autoresume:
        torch.distributed.barrier()
        autoresume.init()
        torch.distributed.barrier()
def _set_random_seed(seed_, data_parallel_random_init=False):
    """Set random seed for reproducibility."""
    if seed_ is not None and seed_ > 0:
        # Ensure that different pipeline MP stages get different seeds.
        seed = seed_ + (100 * mpu.get_pipeline_model_parallel_rank())
        # Ensure different data parallel ranks get different seeds
        if data_parallel_random_init:
            seed = seed + (10 * mpu.get_data_parallel_rank())
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.device_count() > 0:
            mpu.model_parallel_cuda_manual_seed(seed)
    else:
        raise ValueError(
            'Seed ({}) should be a positive integer.'.format(seed_))
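
A quick worked example of the seed offsets above (values invented):

# Illustrative arithmetic only; mirrors the offsets in _set_random_seed().
base_seed = 1234
pp_rank, dp_rank = 2, 3
seed = base_seed + 100 * pp_rank   # 1434: unique per pipeline stage
seed = seed + 10 * dp_rank         # 1464: unique per data-parallel rank (if enabled)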
def write_args_to_tensorboard():
    """Write arguments to tensorboard."""
    args = get_args()
    writer = get_tensorboard_writer()
    if writer:
        for arg in vars(args):
            writer.add_text(arg, str(getattr(args, arg)),
                            global_step=args.iteration)
def _set_jit_fusion_options():
    """Set PyTorch JIT layer fusion options."""
    # flags required to enable jit fusion kernels
    TORCH_MAJOR = int(torch.__version__.split('.')[0])
    TORCH_MINOR = int(torch.__version__.split('.')[1])
    if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10):
        # nvfuser
        torch._C._jit_set_profiling_executor(True)
        torch._C._jit_set_profiling_mode(True)
        torch._C._jit_override_can_fuse_on_cpu(False)
        torch._C._jit_override_can_fuse_on_gpu(False)
        torch._C._jit_set_texpr_fuser_enabled(False)
        torch._C._jit_set_nvfuser_enabled(True)
        torch._C._debug_set_autodiff_subgraph_inlining(False)
    else:
        # legacy pytorch fuser
        torch._C._jit_set_profiling_mode(False)
        torch._C._jit_set_profiling_executor(False)
        torch._C._jit_override_can_fuse_on_cpu(True)
        torch._C._jit_override_can_fuse_on_gpu(True)
3rdparty/Megatron-LM/megatron/learning_rates.py
0 → 100644
View file @ 0211193c
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Learning rate decay functions."""

import math

from megatron import print_rank_0


class AnnealingLR(object):
    """Anneals the learning rate."""

    def __init__(self, optimizer, max_lr, min_lr,
                 warmup_steps, decay_steps, decay_style,
                 use_checkpoint_lr_scheduler=True,
                 override_lr_scheduler=False):

        # Class values.
        self.optimizer = optimizer

        self.max_lr = float(max_lr)
        self.min_lr = min_lr
        assert self.min_lr >= 0.0
        assert self.max_lr >= self.min_lr

        self.warmup_steps = warmup_steps
        self.num_steps = 0
        self.decay_steps = decay_steps
        assert self.decay_steps > 0
        assert self.warmup_steps < self.decay_steps

        self.decay_style = decay_style

        self.override_lr_scheduler = override_lr_scheduler
        self.use_checkpoint_lr_scheduler = use_checkpoint_lr_scheduler
        if self.override_lr_scheduler:
            assert not self.use_checkpoint_lr_scheduler, 'both override and ' \
                'use-checkpoint are set.'

        # Set the learning rate
        self.step(0)

        print_rank_0('> learning rate decay style: {}'.format(
            self.decay_style))
    def get_lr(self):
        """Learning rate decay functions from:
        https://openreview.net/pdf?id=BJYwwY9ll pg. 4"""

        # Use linear warmup for the initial part.
        if self.warmup_steps > 0 and self.num_steps <= self.warmup_steps:
            return self.max_lr * float(self.num_steps) / \
                float(self.warmup_steps)

        # If the learning rate is constant, just return the initial value.
        if self.decay_style == 'constant':
            return self.max_lr

        # For any steps larger than `self.decay_steps`, use `self.min_lr`.
        if self.num_steps > self.decay_steps:
            return self.min_lr

        # If we are done with the warmup period, use the decay style.
        num_steps_ = self.num_steps - self.warmup_steps
        decay_steps_ = self.decay_steps - self.warmup_steps
        decay_ratio = float(num_steps_) / float(decay_steps_)
        assert decay_ratio >= 0.0
        assert decay_ratio <= 1.0
        delta_lr = self.max_lr - self.min_lr

        if self.decay_style == 'linear':
            coeff = (1.0 - decay_ratio)
        elif self.decay_style == 'cosine':
            coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0)
        else:
            raise Exception('{} decay style is not supported.'.format(
                self.decay_style))

        return self.min_lr + coeff * delta_lr
    def step(self, increment):
        """Set lr for all parameters groups."""
        self.num_steps += increment
        new_lr = self.get_lr()
        for group in self.optimizer.param_groups:
            group['lr'] = new_lr

    def state_dict(self):
        state_dict = {
            'max_lr': self.max_lr,
            'warmup_steps': self.warmup_steps,
            'num_steps': self.num_steps,
            'decay_style': self.decay_style,
            'decay_steps': self.decay_steps,
            'min_lr': self.min_lr
        }
        return state_dict
    def _check_and_set(self, cls_value, sd_value, name):
        """Auxiliary function for checking the values in the checkpoint and
        setting them."""
        if self.override_lr_scheduler:
            print_rank_0(' > overriding {} value to {}'.format(name,
                                                               cls_value))
            return cls_value

        if not self.use_checkpoint_lr_scheduler:
            assert cls_value == sd_value, \
                f'AnnealingLR: class input value {cls_value} and checkpoint ' \
                f'value {sd_value} for {name} do not match'
        print_rank_0(' > using checkpoint value {} for {}'.format(sd_value,
                                                                  name))
        return sd_value
    def load_state_dict(self, sd):

        if 'start_lr' in sd:
            max_lr_ = sd['start_lr']
        else:
            max_lr_ = sd['max_lr']
        self.max_lr = self._check_and_set(self.max_lr, max_lr_,
                                          'learning rate')

        self.min_lr = self._check_and_set(self.min_lr, sd['min_lr'],
                                          'minimum learning rate')

        if 'warmup_iter' in sd:
            warmup_steps_ = sd['warmup_iter']
        else:
            warmup_steps_ = sd['warmup_steps']
        self.warmup_steps = self._check_and_set(self.warmup_steps,
                                                warmup_steps_,
                                                'warmup iterations')

        if 'end_iter' in sd:
            decay_steps_ = sd['end_iter']
        else:
            decay_steps_ = sd['decay_steps']
        self.decay_steps = self._check_and_set(self.decay_steps, decay_steps_,
                                               'total number of iterations')

        self.decay_style = self._check_and_set(self.decay_style,
                                               sd['decay_style'],
                                               'decay style')

        if 'num_iters' in sd:
            num_steps = sd['num_iters']
        else:
            num_steps = sd['num_steps']
        self.step(increment=num_steps)
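
A minimal usage sketch of the scheduler (throwaway optimizer, invented hyperparameters; assumed to run where this module's imports resolve):

# Hypothetical snippet, not part of the file above.
import torch

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.SGD(params, lr=0.0)
scheduler = AnnealingLR(optimizer, max_lr=3e-4, min_lr=3e-5,
                        warmup_steps=100, decay_steps=1000,
                        decay_style='cosine')
for _ in range(10):
    # ... forward/backward/optimizer.step() would go here ...
    scheduler.step(increment=1)   # recomputes lr and writes it to all param groups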
3rdparty/Megatron-LM/megatron/memory.py
0 → 100644
View file @ 0211193c
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch


# A dictionary of all the memory buffers allocated.
_MEM_BUFFS = dict()


def allocate_mem_buff(name, numel, dtype, track_usage):
    """Allocate a memory buffer."""
    assert name not in _MEM_BUFFS, \
        'memory buffer {} already allocated.'.format(name)
    _MEM_BUFFS[name] = MemoryBuffer(name, numel, dtype, track_usage)
    return _MEM_BUFFS[name]


def get_mem_buff(name):
    """Get the memory buffer."""
    return _MEM_BUFFS[name]
class MemoryBuffer:
    """Contiguous memory buffer.
    Allocate a contiguous memory of type `dtype` and size `numel`. It is
    used to reduce memory fragmentation.

    Usage: After the allocation, the `_start` index is set to the first
    index of the memory. A memory chunk starting from the `_start` index
    can be `allocated` for an input tensor, with the elements of the
    tensor being copied. The buffer can be reused by resetting the
    `_start` index.
    """

    def __init__(self, name, numel, dtype, track_usage):
        if torch.distributed.get_rank() == 0:
            element_size = torch.tensor([], dtype=dtype).element_size()
            print('> building the {} memory buffer with {} num elements '
                  'and {} dtype ({:.1f} MB)...'.format(
                      name, numel, dtype, numel * element_size / 1024 / 1024),
                  flush=True)
        self.name = name
        self.numel = numel
        self.dtype = dtype
        self.data = torch.empty(self.numel,
                                dtype=self.dtype,
                                device=torch.cuda.current_device(),
                                requires_grad=False)

        # Index tracking the start of the free memory.
        self._start = 0

        # Values used for tracking usage.
        self.track_usage = track_usage
        if self.track_usage:
            self.in_use_value = 0.0
            self.total_value = 0.0

    def reset(self):
        """Reset the buffer start index to the beginning of the buffer."""
        self._start = 0

    def is_in_use(self):
        """Whether the current buffer holds on to any memory."""
        return self._start > 0

    def numel_in_use(self):
        """Return number of elements in use."""
        return self._start

    def add(self, tensor):
        """Allocate a chunk of memory from the buffer to tensor and copy
        the values."""
        assert tensor.dtype == self.dtype, \
            'Input tensor type {} different from buffer type {}'.format(
                tensor.dtype, self.dtype)
        # Number of elements of the input tensor.
        tensor_numel = torch.numel(tensor)
        new_start = self._start + tensor_numel
        assert new_start <= self.numel, \
            'Not enough memory left in the buffer ({} > {})'.format(
                tensor_numel, self.numel - self._start)
        # New tensor is a view into the memory.
        new_tensor = self.data[self._start:new_start]
        self._start = new_start
        new_tensor = new_tensor.view(tensor.shape)
        new_tensor.copy_(tensor)
        # Return a pointer to the new tensor.
        return new_tensor

    def get_data(self):
        """Return the data currently in use."""
        if self.track_usage:
            self.in_use_value += float(self._start)
            self.total_value += float(self.numel)
        return self.data[:self._start]

    def print_average_usage(self):
        """Print memory usage averaged over time. We would like this value
        to be as high as possible."""
        assert self.track_usage, 'You need to enable track usage.'
        if torch.distributed.get_rank() == 0:
            print(' > usage of {} memory buffer: {:.2f} %'.format(
                self.name, self.in_use_value * 100.0 / self.total_value),
                flush=True)
class RingMemBuffer:
    """A ring of memory buffers."""

    def __init__(self, name, num_buffers, numel, dtype, track_usage):
        self.num_buffers = num_buffers
        self.buffers = [
            allocate_mem_buff(name + ' {}'.format(i), numel, dtype,
                              track_usage)
            for i in range(num_buffers)]
        self._index = -1

    def get_next_buffer(self):
        self._index += 1
        self._index = self._index % self.num_buffers
        buff = self.buffers[self._index]
        assert not buff.is_in_use(), 'buffer is already in use.'
        return buff
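
A minimal usage sketch of the buffer (invented sizes; assumes a CUDA device and an initialized torch.distributed process group, since __init__ prints from rank 0):

# Hypothetical snippet, not part of the file above.
import torch

buff = allocate_mem_buff('activations', numel=1 << 20,
                         dtype=torch.float16, track_usage=False)
x = torch.randn(256, 1024, dtype=torch.float16, device='cuda')
x_view = buff.add(x)          # x's values now live inside the big buffer
assert x_view.shape == x.shape
buff.reset()                  # reclaim the whole buffer for the next iteration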
3rdparty/Megatron-LM/megatron/microbatches.py
0 → 100644
View file @ 0211193c
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Megatron number of micro-batches calculators."""

from abc import ABC
from abc import abstractmethod


def build_num_microbatches_calculator(args):

    # Constant num micro-batches.
    if args.rampup_batch_size is None:
        num_microbatches_calculator = ConstantNumMicroBatches(
            args.global_batch_size, args.micro_batch_size,
            args.data_parallel_size)
        if args.rank == 0:
            print('setting number of micro-batches to constant {}'.format(
                num_microbatches_calculator.get()), flush=True)

    else:
        assert len(args.rampup_batch_size) == 3, 'expected the following ' \
            'format: --rampup-batch-size <start batch size> ' \
            '<batch size increment> <ramp-up samples>'
        start_batch_size = int(args.rampup_batch_size[0])
        batch_size_increment = int(args.rampup_batch_size[1])
        ramup_samples = int(args.rampup_batch_size[2])
        if args.rank == 0:
            print('will use batch size rampup starting from global batch '
                  'size {} to global batch size {} with batch size increments '
                  '{} over {} samples.'.format(start_batch_size,
                                               args.global_batch_size,
                                               batch_size_increment,
                                               ramup_samples), flush=True)
        num_microbatches_calculator = RampupBatchsizeNumMicroBatches(
            start_batch_size, batch_size_increment, ramup_samples,
            args.global_batch_size, args.micro_batch_size,
            args.data_parallel_size)

    return num_microbatches_calculator
class NumMicroBatchesCalculator(ABC):

    def __init__(self):
        self.num_micro_batches = None
        self.current_global_batch_size = None

    def get(self):
        return self.num_micro_batches

    def get_current_global_batch_size(self):
        return self.current_global_batch_size

    @abstractmethod
    def update(self, consumed_samples, consistency_check):
        pass


class ConstantNumMicroBatches(NumMicroBatchesCalculator):

    def __init__(self, global_batch_size, micro_batch_size,
                 data_parallel_size):
        micro_batch_times_data_parallel = micro_batch_size * \
            data_parallel_size
        assert global_batch_size % micro_batch_times_data_parallel == 0, \
            'global batch size ({}) is not divisible by micro batch size ({})' \
            ' times data parallel size ({})'.format(global_batch_size,
                                                    micro_batch_size,
                                                    data_parallel_size)
        self.num_micro_batches = global_batch_size // \
            micro_batch_times_data_parallel
        assert self.num_micro_batches >= 1
        self.current_global_batch_size = global_batch_size

    def update(self, consumed_samples, consistency_check):
        pass
class RampupBatchsizeNumMicroBatches(NumMicroBatchesCalculator):

    def __init__(self, start_batch_size, batch_size_increment, ramup_samples,
                 global_batch_size, micro_batch_size, data_parallel_size):
        """Batch size ramp up.
        Over
            steps = (global-batch-size - start-batch-size) / batch_size_increment
        increments, increase the batch size from start-batch-size to
        global-batch-size, using
            rampup-samples / steps
        samples for each increment.

        Arguments:
            start_batch_size: global batch size to start with
            batch_size_increment: global batch size increments
            ramup_samples: number of samples to use to ramp up the global
                batch size from `start_batch_size` to `global_batch_size`
            global_batch_size: global batch size post rampup
            micro_batch_size: micro batch size
            data_parallel_size: data parallel size.
        """

        self.micro_batch_size = micro_batch_size
        self.data_parallel_size = data_parallel_size
        self.micro_batch_times_data_parallel_size = self.micro_batch_size * \
            self.data_parallel_size
        assert self.micro_batch_times_data_parallel_size > 0

        assert start_batch_size > 0
        self.start_batch_size = start_batch_size

        assert global_batch_size > 0
        self.global_batch_size = global_batch_size
        diff_batch_size = self.global_batch_size - self.start_batch_size
        assert diff_batch_size >= 0
        assert batch_size_increment > 0
        self.batch_size_increment = batch_size_increment
        assert diff_batch_size % batch_size_increment == 0, 'expected ' \
            'global batch size interval ({}) to be divisible by global batch ' \
            'size increment ({})'.format(diff_batch_size, batch_size_increment)

        num_increments = diff_batch_size // self.batch_size_increment
        self.ramup_samples = ramup_samples
        assert self.ramup_samples >= 0
        self.rampup_samples_per_increment = self.ramup_samples / num_increments

        # Initialize number of microbatches.
        self.update(0, False)

    def update(self, consumed_samples, consistency_check):

        if consumed_samples > self.ramup_samples:
            self.current_global_batch_size = self.global_batch_size
        else:
            steps = int(consumed_samples / self.rampup_samples_per_increment)
            self.current_global_batch_size = self.start_batch_size + \
                steps * self.batch_size_increment
            assert self.current_global_batch_size <= self.global_batch_size

        if consistency_check:
            assert self.current_global_batch_size % \
                self.micro_batch_times_data_parallel_size == 0, 'current global ' \
                'batch size ({}) is not divisible by micro-batch-size ({}) times ' \
                'data parallel size ({})'.format(self.current_global_batch_size,
                                                 self.micro_batch_size,
                                                 self.data_parallel_size)
        self.num_micro_batches = self.current_global_batch_size // \
            self.micro_batch_times_data_parallel_size
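
A worked example with invented numbers: ramping from a global batch size of 32 to 1024 in increments of 32 over 1,000,000 samples yields 31 increments of roughly 32,258 samples each.

# Illustrative only; mirrors the arithmetic in the classes above.
calc = RampupBatchsizeNumMicroBatches(
    start_batch_size=32, batch_size_increment=32, ramup_samples=1_000_000,
    global_batch_size=1024, micro_batch_size=4, data_parallel_size=8)

calc.update(consumed_samples=0, consistency_check=True)
assert calc.get_current_global_batch_size() == 32
assert calc.get() == 1            # 32 / (4 * 8)

calc.update(consumed_samples=2_000_000, consistency_check=True)
assert calc.get_current_global_batch_size() == 1024
assert calc.get() == 32           # 1024 / (4 * 8)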
3rdparty/Megatron-LM/megatron/model/__init__.py
0 → 100644
View file @ 0211193c
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .fused_layer_norm import MixedFusedLayerNorm as LayerNorm

from .distributed import DistributedDataParallel
from .bert_model import BertModel
from .gpt_model import GPTModel
from .t5_model import T5Model
from .language_model import get_language_model
from .module import Float16Module
from .enums import ModelType
3rdparty/Megatron-LM/megatron/model/bert_model.py
0 → 100644
View file @ 0211193c
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""BERT model."""

import torch

from megatron import get_args
from megatron import mpu
from megatron.model.enums import AttnMaskType
from megatron.model.language_model import parallel_lm_logits
from megatron.model.language_model import get_language_model
from megatron.model import LayerNorm
from megatron.model.utils import openai_gelu, erf_gelu
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal
from .module import MegatronModule
def bert_extended_attention_mask(attention_mask):
    # We create a 3D attention mask from a 2D tensor mask.
    # [b, 1, s]
    attention_mask_b1s = attention_mask.unsqueeze(1)
    # [b, s, 1]
    attention_mask_bs1 = attention_mask.unsqueeze(2)
    # [b, s, s]
    attention_mask_bss = attention_mask_b1s * attention_mask_bs1
    # [b, 1, s, s]
    extended_attention_mask = attention_mask_bss.unsqueeze(1)

    # Convert attention mask to binary:
    extended_attention_mask = (extended_attention_mask < 0.5)

    return extended_attention_mask


def bert_position_ids(token_ids):
    # Create position ids
    seq_length = token_ids.size(1)
    position_ids = torch.arange(seq_length, dtype=torch.long,
                                device=token_ids.device)
    position_ids = position_ids.unsqueeze(0).expand_as(token_ids)

    return position_ids
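
To make the shape bookkeeping concrete, a small illustration (invented values): a [b, s] padding mask becomes a [b, 1, s, s] boolean mask in which True marks positions to mask out.

# Illustrative only; 1 = real token, 0 = padding in the input mask.
import torch

mask = torch.tensor([[1, 1, 0]])            # [b=1, s=3]
ext = bert_extended_attention_mask(mask)    # [1, 1, 3, 3], dtype=torch.bool
# ext[0, 0] is True wherever the query or the key position is padding:
# [[False, False, True],
#  [False, False, True],
#  [True,  True,  True]]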
class BertLMHead(MegatronModule):
    """Masked LM head for Bert

    Arguments:
        mpu_vocab_size: model parallel size of vocabulary.
        hidden_size: hidden size
        init_method: init method for weight initialization
        layernorm_epsilon: tolerance for layer norm divisions
        parallel_output: whether the output logits are kept distributed
            or gathered.
    """

    def __init__(self, mpu_vocab_size, hidden_size, init_method,
                 layernorm_epsilon, parallel_output):

        super(BertLMHead, self).__init__()

        args = get_args()

        self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
        mpu.set_tensor_model_parallel_attributes(self.bias, True, 0, 1)
        self.parallel_output = parallel_output

        self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
        self.layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
        self.gelu = torch.nn.functional.gelu
        if args.openai_gelu:
            self.gelu = openai_gelu
        elif args.onnx_safe:
            self.gelu = erf_gelu

    def forward(self, hidden_states, word_embeddings_weight):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.gelu(hidden_states)
        hidden_states = self.layernorm(hidden_states)
        output = parallel_lm_logits(hidden_states,
                                    word_embeddings_weight,
                                    self.parallel_output,
                                    bias=self.bias)
        return output
def post_language_model_processing(lm_output, pooled_output,
                                   lm_head, binary_head,
                                   lm_labels,
                                   logit_weights,
                                   fp16_lm_cross_entropy):
    # Output.
    lm_logits = lm_head(lm_output, logit_weights)

    binary_logits = None
    if binary_head is not None:
        binary_logits = binary_head(pooled_output)

    if lm_labels is None:
        return lm_logits, binary_logits
    else:
        if fp16_lm_cross_entropy:
            assert lm_logits.dtype == torch.half
            lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels)
        else:
            lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits.float(),
                                                       lm_labels)
        return lm_loss, binary_logits
class BertModel(MegatronModule):
    """Bert Language model."""

    def __init__(self,
                 num_tokentypes=2,
                 add_binary_head=True,
                 parallel_output=True,
                 pre_process=True,
                 post_process=True):
        super(BertModel, self).__init__()
        args = get_args()

        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
        self.add_binary_head = add_binary_head
        self.parallel_output = parallel_output
        self.pre_process = pre_process
        self.post_process = post_process

        init_method = init_method_normal(args.init_method_std)
        scaled_init_method = scaled_init_method_normal(args.init_method_std,
                                                       args.num_layers)

        self.language_model, self._language_model_key = get_language_model(
            num_tokentypes=num_tokentypes,
            add_pooler=self.add_binary_head,
            encoder_attn_mask_type=AttnMaskType.padding,
            init_method=init_method,
            scaled_init_method=scaled_init_method,
            pre_process=self.pre_process,
            post_process=self.post_process)

        self.initialize_word_embeddings(init_method_normal)
        if self.post_process:
            self.lm_head = BertLMHead(
                self.word_embeddings_weight().size(0),
                args.hidden_size, init_method, args.layernorm_epsilon,
                parallel_output)
            self._lm_head_key = 'lm_head'
            self.binary_head = None
            if self.add_binary_head:
                self.binary_head = get_linear_layer(args.hidden_size, 2,
                                                    init_method)
                self._binary_head_key = 'binary_head'

    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()"""
        self.language_model.set_input_tensor(input_tensor)

    def forward(self, bert_model_input, attention_mask,
                tokentype_ids=None, lm_labels=None):

        extended_attention_mask = bert_extended_attention_mask(attention_mask)
        input_ids = bert_model_input
        position_ids = bert_position_ids(input_ids)

        lm_output = self.language_model(
            input_ids,
            position_ids,
            extended_attention_mask,
            tokentype_ids=tokentype_ids
        )

        if self.post_process and self.add_binary_head:
            lm_output, pooled_output = lm_output
        else:
            pooled_output = None

        if self.post_process:
            return post_language_model_processing(lm_output, pooled_output,
                                                  self.lm_head,
                                                  self.binary_head,
                                                  lm_labels,
                                                  self.word_embeddings_weight(),
                                                  self.fp16_lm_cross_entropy)
        else:
            return lm_output
    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """For easy load when model is combined with other heads,
        add an extra key."""

        state_dict_ = {}
        state_dict_[self._language_model_key] \
            = self.language_model.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars)
        if self.post_process:
            state_dict_[self._lm_head_key] \
                = self.lm_head.state_dict_for_save_checkpoint(
                    destination, prefix, keep_vars)
        if self.post_process and self.add_binary_head:
            state_dict_[self._binary_head_key] \
                = self.binary_head.state_dict(destination, prefix, keep_vars)
        # Save word_embeddings.
        if self.post_process and not self.pre_process:
            state_dict_[self._word_embeddings_for_head_key] \
                = self.word_embeddings.state_dict(destination, prefix,
                                                  keep_vars)
        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""

        self.language_model.load_state_dict(
            state_dict[self._language_model_key], strict=strict)
        if self.post_process:
            self.lm_head.load_state_dict(
                state_dict[self._lm_head_key], strict=strict)
        if self.post_process and self.add_binary_head:
            self.binary_head.load_state_dict(
                state_dict[self._binary_head_key], strict=strict)
        # Load word_embeddings.
        if self.post_process and not self.pre_process:
            self.word_embeddings.load_state_dict(
                state_dict[self._word_embeddings_for_head_key], strict=strict)
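
Note that BertLMHead reuses the word embedding matrix as the output projection (weight tying). A simplified, single-GPU sketch of the idea behind parallel_lm_logits (the real function also handles the vocab-parallel split across ranks; sizes below are invented):

import torch

hidden = torch.randn(2, 8, 512)         # [batch, seq, hidden]
emb_weight = torch.randn(30522, 512)    # [vocab, hidden], shared with the input embedding
bias = torch.zeros(30522)
logits = torch.nn.functional.linear(hidden, emb_weight, bias)   # [2, 8, 30522]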