OpenDAS / Megatron-LM / Commits

Commit 91ee60df, authored Sep 11, 2020 by Jared Casper

Merge branch 'vijay/softmax_fusion' into 'main'

Various speed optimizations. See merge request ADLR/megatron-lm!124

Parents: 12518332, 51a2e6b0
Showing 13 changed files with 1103 additions and 163 deletions (+1103, -163).
Changed files:

    megatron/arguments.py                                               +13    -1
    megatron/fused_kernels/__init__.py                                  +53    -0
    megatron/fused_kernels/scaled_upper_triang_masked_softmax.cpp       +69    -0
    megatron/fused_kernels/scaled_upper_triang_masked_softmax.h        +439    -0
    megatron/fused_kernels/scaled_upper_triang_masked_softmax_cuda.cu   +89    -0
    megatron/model/fused_bias_gelu.py                                    +60    -0
    megatron/model/fused_softmax.py                                      +94    -0
    megatron/model/language_model.py                                      +2   -12
    megatron/model/transformer.py                                       +244  -132
    megatron/mpu/layers.py                                               +25    -9
    megatron/training.py                                                 +12    -3
    megatron/utils.py                                                     +3    -3
    pretrain_gpt2.py                                                      +0    -3
megatron/arguments.py

@@ -19,7 +19,7 @@ import argparse
 import os
 
 import torch
 
+from megatron import fused_kernels
 
 def parse_args(extra_args_provider=None, defaults={},
                ignore_unknown_args=False):
@@ -118,6 +118,10 @@ def parse_args(extra_args_provider=None, defaults={},
             'for distribute-checkpointed-activations to work you '\
             'need to enable checkpoint-activations'
 
+    # load scaled_upper_triang_masked_softmax_fusion kernel
+    if args.scaled_upper_triang_masked_softmax_fusion:
+        fused_kernels.load_scaled_upper_triang_masked_softmax_fusion_kernel()
+
     _print_args(args)
     return args
@@ -221,6 +225,14 @@ def _add_training_args(parser):
                        'by this value.')
     group.add_argument('--tensorboard-dir', type=str, default=None,
                        help='Write TensorBoard logs to this directory.')
+    group.add_argument('--scaled-upper-triang-masked-softmax-fusion',
+                       action='store_true',
+                       help='Enable fusion of query_key_value_scaling '
+                       'time (upper diagonal) masking, softmax.')
+    group.add_argument('--bias-gelu-fusion', action='store_true',
+                       help='Enable bias and gelu fusion.')
+    group.add_argument('--bias-dropout-fusion', action='store_true',
+                       help='Enable bias and dropout fusion.')
 
     return parser
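For reference, a minimal sketch (not part of the commit) of how the three new store_true fusion flags behave with argparse; the parser below mirrors only the options added in _add_training_args and omits everything else.

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--scaled-upper-triang-masked-softmax-fusion',
                        action='store_true')
    parser.add_argument('--bias-gelu-fusion', action='store_true')
    parser.add_argument('--bias-dropout-fusion', action='store_true')

    args = parser.parse_args(['--scaled-upper-triang-masked-softmax-fusion',
                              '--bias-gelu-fusion',
                              '--bias-dropout-fusion'])
    # All three fusions are opt-in; omitting a flag leaves it False.
    assert args.scaled_upper_triang_masked_softmax_fusion
    assert args.bias_gelu_fusion and args.bias_dropout_fusion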
megatron/fused_kernels/__init__.py  (new file, mode 100644)

# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pathlib
import subprocess

from torch.utils import cpp_extension


def load_scaled_upper_triang_masked_softmax_fusion_kernel():

    def get_cuda_bare_metal_version(cuda_dir):
        raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"],
                                             universal_newlines=True)
        output = raw_output.split()
        release_idx = output.index("release") + 1
        release = output[release_idx].split(".")
        bare_metal_major = release[0]
        bare_metal_minor = release[1][0]

        return raw_output, bare_metal_major, bare_metal_minor

    # Check, if CUDA11 is installed for compute capability 8.0
    cc_flag = []
    _, bare_metal_major, _ = get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
    if int(bare_metal_major) >= 11:
        cc_flag.append('-gencode')
        cc_flag.append('arch=compute_80,code=sm_80')

    srcpath = pathlib.Path(__file__).parent.absolute()
    scaled_upper_triang_masked_softmax_cuda = cpp_extension.load(
        name='scaled_upper_triang_masked_softmax_cuda',
        sources=[srcpath / 'scaled_upper_triang_masked_softmax.cpp',
                 srcpath / 'scaled_upper_triang_masked_softmax_cuda.cu'],
        extra_cflags=['-O3',],
        extra_cuda_cflags=['-O3',
                           '-gencode', 'arch=compute_70,code=sm_70',
                           '-U__CUDA_NO_HALF_OPERATORS__',
                           '-U__CUDA_NO_HALF_CONVERSIONS__',
                           '--expt-relaxed-constexpr',
                           '--expt-extended-lambda',
                           '--use_fast_math'] + cc_flag,
        verbose=True)
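A usage sketch (not part of the commit; it assumes a CUDA toolkit and a host C++ compiler are available at run time): cpp_extension.load() JIT-builds the .cpp/.cu sources on first call and registers the resulting module in the process, so a plain import by name succeeds afterwards, which is how megatron/model/fused_softmax.py picks it up.

    from megatron import fused_kernels

    # Builds the extension the first time it is called (can take a minute).
    fused_kernels.load_scaled_upper_triang_masked_softmax_fusion_kernel()

    # After the build, the extension is importable like a normal module.
    import scaled_upper_triang_masked_softmax_cuda
    print(scaled_upper_triang_masked_softmax_cuda.forward.__doc__)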
megatron/fused_kernels/scaled_upper_triang_masked_softmax.cpp  (new file, mode 100644)

/* coding=utf-8
 * Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <cuda_fp16.h>
#include <torch/extension.h>
#include <vector>

namespace multihead_attn {
namespace fused_softmax {
namespace scaled_upper_triang_masked_softmax {

torch::Tensor fwd_cuda(
    torch::Tensor const& input,
    float scale_factor);

torch::Tensor bwd_cuda(
    torch::Tensor const& output_grads,
    torch::Tensor const& softmax_results,
    float scale_factor);

torch::Tensor fwd(torch::Tensor const& input, float scale_factor) {
  AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(input.scalar_type() == at::ScalarType::Half,
             "Only HALF is supported");

  return fwd_cuda(input, scale_factor);
}

torch::Tensor bwd(
    torch::Tensor const& output_grads,
    torch::Tensor const& softmax_results,
    float scale_factor) {
  AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");

  AT_ASSERTM(output_grads.scalar_type() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(softmax_results.scalar_type() == at::ScalarType::Half,
             "Only HALF is supported");

  return bwd_cuda(output_grads, softmax_results, scale_factor);
}

} // end namespace scaled_upper_triang_masked_softmax
} // end namespace fused_softmax
} // end namespace multihead_attn

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward",
        &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd,
        "Self Multihead Attention scaled, time masked softmax -- Forward.");
  m.def("backward",
        &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd,
        "Self Multihead Attention scaled, time masked softmax -- Backward.");
}
megatron/fused_kernels/scaled_upper_triang_masked_softmax.h  (new file, mode 100644)

/* coding=utf-8
 * Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <assert.h>
#include <cuda_fp16.h>
#include <cfloat>
#include <limits>
#include <stdint.h>
#include <cuda_fp16.h>
#include <c10/macros/Macros.h>

namespace {

int log2_ceil(int value) {
    int log2_value = 0;
    while ((1 << log2_value) < value) ++log2_value;
    return log2_value;
}

template<typename T>
struct Add {
  __device__ __forceinline__ T operator()(T a, T b) const {
    return a + b;
  }
};

template<typename T>
struct Max {
  __device__ __forceinline__ T operator()(T a, T b) const {
    return a < b ? b : a;
  }
};

template <typename T>
__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff)
{
#if CUDA_VERSION >= 9000
    return __shfl_xor_sync(mask, value, laneMask, width);
#else
    return __shfl_xor(value, laneMask, width);
#endif
}

template <typename acc_t, int WARP_BATCH, int WARP_SIZE, template<typename> class ReduceOp>
__device__ __forceinline__ void warp_reduce(acc_t* sum) {
    ReduceOp<acc_t> r;
    #pragma unroll
    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
        #pragma unroll
        for (int i = 0; i < WARP_BATCH; ++i) {
            acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE);
            sum[i] = r(sum[i], b);
        }
    }
}

/*
 * Extended softmax (from native aten pytorch) with following additional features
 * 1) input scaling
 * 2) Implicit time (diagonal masking)
 */
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_upper_triang_masked_softmax_warp_forward(
    output_t *dst,
    const input_t *src,
    const acc_t scale,
    int batch_size,
    int stride,
    int element_count)
{
    // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
    // warp_size of method warp_softmax_forward_kernel.
    constexpr int next_power_of_two = 1 << log2_elements;
    constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
    constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
    constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;

    int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x;
    int local_seq = blockIdx.x + 1;
    int warp_iteration_limit = (local_seq + WARP_SIZE - 1) / WARP_SIZE;

    // batch_size might not be a multiple of WARP_BATCH. Check how
    // many batches have to computed within this WARP.
    int local_batches = batch_size - first_batch;
    if (local_batches > WARP_BATCH)
        local_batches = WARP_BATCH;

    // there might be multiple batches per warp. compute the index within the batch
    int local_idx = threadIdx.x;

    src += first_batch * stride + local_idx;
    dst += first_batch * stride + local_idx;

    // load data from global memory
    acc_t elements[WARP_BATCH][WARP_ITERATIONS];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        int batch_element_count = (i >= local_batches) ? 0 : local_seq;

        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; ++it) {
            int element_index = local_idx + it * WARP_SIZE;
            if (element_index < batch_element_count) {
                elements[i][it] = (acc_t)src[i*element_count*stride + it*WARP_SIZE] * scale;
            } else {
                elements[i][it] = -std::numeric_limits<acc_t>::infinity();
            }
        }
    }

    // compute max_value
    acc_t max_value[WARP_BATCH];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        max_value[i] = elements[i][0];
        #pragma unroll
        for (int it = 1; it < WARP_ITERATIONS; ++it) {
            max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
        }
    }
    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);

    acc_t sum[WARP_BATCH] { 0.0f };
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; ++it) {
            if (it < warp_iteration_limit) {
                elements[i][it] = std::exp((elements[i][it] - max_value[i]));
                sum[i] += elements[i][it];
            }
        }
    }
    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);

    // store result
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        if (i >= local_batches)
            break;
        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; ++it) {
            int element_index = local_idx + it * WARP_SIZE;
            if (element_index < local_seq) {
                dst[i*element_count*stride + it*WARP_SIZE] = (output_t)(elements[i][it] / sum[i]);
            } else if (element_index < element_count) {
                dst[i*element_count*stride + it*WARP_SIZE] = 0;
            } else {
                break;
            }
        }
    }
}

template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_upper_triang_masked_softmax_warp_backward(
    output_t *gradInput,
    input_t *grad,
    const input_t *output,
    acc_t scale,
    int batch_size,
    int stride,
    int element_count)
{
    // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
    // warp_size of method warp_softmax_backward_kernel.
    constexpr int next_power_of_two = 1 << log2_elements;
    constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
    constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
    constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;

    int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x;
    int local_seq = blockIdx.x + 1;

    // batch_size might not be a multiple of WARP_BATCH. Check how
    // many batches have to computed within this WARP.
    int local_batches = batch_size - first_batch;
    if (local_batches > WARP_BATCH)
        local_batches = WARP_BATCH;

    // there might be multiple batches per warp. compute the index within the batch
    int local_idx = threadIdx.x;

    // the first element to process by the current thread
    int thread_offset = first_batch * stride + local_idx;
    grad += thread_offset;
    output += thread_offset;
    gradInput += thread_offset;

    // load data from global memory
    acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
    acc_t output_reg[WARP_BATCH][WARP_ITERATIONS];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        int batch_element_count = (i >= local_batches) ? 0 : local_seq;

        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; ++it) {
            int element_index = local_idx + it * WARP_SIZE;
            if (element_index < batch_element_count) {
                output_reg[i][it] = output[i*element_count*stride + it*WARP_SIZE];
            } else {
                output_reg[i][it] = acc_t(0);
            }
        }

        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; ++it) {
            int element_index = local_idx + it * WARP_SIZE;
            if (element_index < batch_element_count) {
                grad_reg[i][it] = (acc_t)grad[i*element_count*stride + it*WARP_SIZE] * output_reg[i][it];
            } else {
                grad_reg[i][it] = acc_t(0);
            }
        }
    }

    acc_t sum[WARP_BATCH];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        sum[i] = grad_reg[i][0];
        #pragma unroll
        for (int it = 1; it < WARP_ITERATIONS; ++it) {
            sum[i] += grad_reg[i][it];
        }
    }
    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);

    // store result
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        if (i >= local_batches)
            break;
        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; ++it) {
            int element_index = local_idx + it * WARP_SIZE;
            if (element_index < element_count) {
                // compute gradients
                gradInput[i*element_count*stride + it*WARP_SIZE] =
                    (output_t)(scale * (grad_reg[i][it] - output_reg[i][it] * sum[i]));
            }
        }
    }
}

} // end of anonymous namespace

template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_upper_triang_masked_softmax_forward(
    output_t *dst,
    const input_t *src,
    const input_t scale,
    int softmax_elements,
    int softmax_elements_stride,
    int attn_batches)
{
    TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048);
    if (softmax_elements == 0) {
        return;
    } else {
        int log2_elements = log2_ceil(softmax_elements);
        const int next_power_of_two = 1 << log2_elements;
        int seq_len = softmax_elements;
        int batch_count = attn_batches * seq_len;

        // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
        int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;

        // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward.
        int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

        // use 128 threads per block to maximimize gpu utilization
        constexpr int threads_per_block = 128;

        int warps_per_block = (threads_per_block / warp_size);
        int batches_per_block = warps_per_block * batches_per_warp;
        TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);

        int blocks_per_seq = attn_batches / batches_per_block;
        dim3 blocks(seq_len, blocks_per_seq, 1);
        dim3 threads(warp_size, warps_per_block, 1);
        // Launch code would be more elegant if C++ supported FOR CONSTEXPR
        switch (log2_elements) {
            case 0: // 1
                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 0>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 1: // 2
                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 1>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 2: // 4
                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 2>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 3: // 8
                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 3>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 4: // 16
                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 4>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 5: // 32
                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 5>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 6: // 64
                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 6>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 7: // 128
                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 7>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 8: // 256
                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 8>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 9: // 512
                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 9>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 10: // 1024
                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 10>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 11: // 2048
                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 11>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            default:
                break;
        }
    }
}

template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_upper_triang_masked_softmax_backward(
    output_t *grad_input,
    input_t *grad,
    const input_t *output,
    const acc_t scale,
    int softmax_elements,
    int softmax_elements_stride,
    int attn_batches)
{
    TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048);
    if (softmax_elements == 0) {
        return;
    } else {
        int log2_elements = log2_ceil(softmax_elements);
        const int next_power_of_two = 1 << log2_elements;
        int seq_len = softmax_elements;
        int batch_count = attn_batches * seq_len;

        // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.
        int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;

        // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward.
        int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

        // use 128 threads per block to maximimize gpu utilization
        constexpr int threads_per_block = 128;

        int warps_per_block = (threads_per_block / warp_size);
        int batches_per_block = warps_per_block * batches_per_warp;
        TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);

        int blocks_per_seq = attn_batches / batches_per_block;
        dim3 blocks(seq_len, blocks_per_seq, 1);
        dim3 threads(warp_size, warps_per_block, 1);
        // Launch code would be more elegant if C++ supported FOR CONSTEXPR
        switch (log2_elements) {
            case 0: // 1
                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 0>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 1: // 2
                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 1>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 2: // 4
                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 2>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 3: // 8
                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 3>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 4: // 16
                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 4>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 5: // 32
                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 5>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 6: // 64
                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 6>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 7: // 128
                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 7>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 8: // 256
                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 8>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 9: // 512
                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 9>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 10: // 1024
                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 10>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 11: // 2048
                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 11>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            default:
                break;
        }
    }
}
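For readers following the kernel above, a pure-PyTorch reference (a sketch, not part of the commit) of what it computes: scale the [attn_batches, seq_len, seq_len] scores, mask out the strict upper triangle (future positions), and softmax each row. This is convenient for comparing the fused kernel against an unfused baseline.

    import torch

    def reference_scaled_upper_triang_masked_softmax(x, scale):
        # x: [attn_batches, seq_len, seq_len]; the causal mask keeps j <= i.
        seq_len = x.size(-1)
        mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool,
                                     device=x.device), diagonal=1)
        scores = (x * scale).masked_fill(mask, float('-inf'))
        return torch.softmax(scores, dim=-1)

    # Example check against the fused kernel (assumes it has been built/loaded):
    #   y_ref = reference_scaled_upper_triang_masked_softmax(x.float(), scale).half()
    #   y_fused = scaled_upper_triang_masked_softmax_cuda.forward(x, scale)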
megatron/fused_kernels/scaled_upper_triang_masked_softmax_cuda.cu  (new file, mode 100644)

/* coding=utf-8
 * Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include "THC/THC.h"
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "scaled_upper_triang_masked_softmax.h"

namespace multihead_attn {
namespace fused_softmax {
namespace scaled_upper_triang_masked_softmax {

torch::Tensor fwd_cuda(
    torch::Tensor const& input,
    float scale_factor)
{
  // input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]
  const int attn_batches = input.size(0);
  const int seq_len = input.size(1);
  TORCH_INTERNAL_ASSERT(seq_len <= 2048);

  // Output
  auto act_options = input.options().requires_grad(false);
  torch::Tensor softmax_results =
      torch::empty({attn_batches, seq_len, seq_len}, act_options);

  // Softmax Intermediate Result Ptr
  void* input_ptr = static_cast<void*>(input.data_ptr());
  void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());

  dispatch_scaled_upper_triang_masked_softmax_forward<half, half, float>(
      reinterpret_cast<half*>(softmax_results_ptr),
      reinterpret_cast<const half*>(input_ptr),
      scale_factor,
      seq_len,
      seq_len,
      attn_batches);

  return softmax_results;
}

torch::Tensor bwd_cuda(
    torch::Tensor const& output_grads_,
    torch::Tensor const& softmax_results_,
    float scale_factor)
{
  auto output_grads = output_grads_.contiguous();
  auto softmax_results = softmax_results_.contiguous();

  //output grads is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]
  const int attn_batches = output_grads.size(0);
  const int seq_len = output_grads.size(1);
  TORCH_INTERNAL_ASSERT(output_grads.size(1) == output_grads.size(2));

  void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());

  //Softmax Grad
  dispatch_scaled_upper_triang_masked_softmax_backward<half, half, float>(
      reinterpret_cast<half*>(output_grads_ptr),
      reinterpret_cast<half*>(output_grads_ptr),
      reinterpret_cast<half const*>(softmax_results.data_ptr()),
      scale_factor,
      seq_len,
      seq_len,
      attn_batches);

  //backward pass is completely in-place
  return output_grads;
}

}
}
}
megatron/model/fused_bias_gelu.py  (new file, mode 100644)

# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch


torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)

###### BIAS GELU FUSION/ NO AUTOGRAD ################
# 1/sqrt(2*pi)-> 0.3989423
# 1/sqrt(2)   -> 0.70710678
# sqrt(2/pi)  -> 0.79788456
# this function is tanh approximation of gelu
# actual gelu is:
# x * 0.5 * (1.0 + torch.erf(x * 0.70710678))

@torch.jit.script
def bias_gelu(bias, y):
    x = bias + y
    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))

# gradient of tanh approximation of gelu
# gradient of actual gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
@torch.jit.script
def bias_gelu_back(g, bias, y):
    x = bias + y
    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
    return ff*g

class GeLUFunction(torch.autograd.Function):
    @staticmethod
    # bias is an optional argument
    def forward(ctx, input, bias):
        ctx.save_for_backward(input, bias)
        return bias_gelu(bias, input)

    @staticmethod
    def backward(ctx, grad_output):
        input, bias = ctx.saved_tensors
        tmp = bias_gelu_back(grad_output, bias, input)
        return tmp, tmp

bias_gelu_impl = GeLUFunction.apply
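A small consistency check (a sketch, not part of the commit): the hand-written backward bias_gelu_back should match autograd applied to the tanh-approximate forward. Double precision keeps the comparison tight; tensor shapes here are arbitrary.

    import torch
    from megatron.model.fused_bias_gelu import bias_gelu, bias_gelu_back

    y = torch.randn(8, 16, dtype=torch.double, requires_grad=True)
    bias = torch.randn(16, dtype=torch.double, requires_grad=True)

    out = bias_gelu(bias, y)
    g = torch.randn_like(out)
    out.backward(g)

    # Autograd's d(out)/dy equals the manual formula ff * g.
    manual = bias_gelu_back(g, bias, y)
    assert torch.allclose(y.grad, manual)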
megatron/model/fused_softmax.py  (new file, mode 100644)

# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch


class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
    """
    Fused operation which performs following three operations in sequence
    1. Scale the tensor.
    2. Apply upper triangular mask (typically used in gpt models).
    3. Perform softmax.
    """
    @staticmethod
    def forward(ctx, inputs, scale):
        import scaled_upper_triang_masked_softmax_cuda
        scale_t = torch.tensor([scale])

        softmax_results = \
            scaled_upper_triang_masked_softmax_cuda.forward(inputs, scale_t[0])
        ctx.save_for_backward(softmax_results, scale_t)
        return softmax_results

    @staticmethod
    def backward(ctx, output_grads):
        import scaled_upper_triang_masked_softmax_cuda
        softmax_results, scale_t = ctx.saved_tensors

        input_grads = \
            scaled_upper_triang_masked_softmax_cuda.backward(output_grads,
                                                             softmax_results,
                                                             scale_t[0])
        return input_grads, None


class FusedScaleMaskSoftmax(torch.nn.Module):
    """
    fused operation: scaling + mask + softmax
    Arguments:
        input_in_fp16: flag to indicate if input in fp16 data format.
        upper_triang_mask: if true, apply upper triangular masking.
                           (used in gpt family networks)
        mask_func: mask function to be applied.
        softmax_in_fp32: if true, softmax in performed at fp32 precision.
        scale: scaling factor used in input tensor scaling.
    """
    def __init__(self, input_in_fp16, upper_triang_mask, mask_func,
                 softmax_in_fp32, scale):
        super(FusedScaleMaskSoftmax, self).__init__()
        self.input_in_fp16 = input_in_fp16
        self.upper_triang_mask = upper_triang_mask
        self.mask_func = mask_func
        self.softmax_in_fp32 = softmax_in_fp32
        self.scale = scale

        assert self.scale is None or softmax_in_fp32, \
            'softmax should be in fp32 when scaled'

    def forward(self, input, mask):
        # [b, np, s, s]
        data_size = input.size()
        assert input.dim() == 4

        # invoke custom kernel for implicit upper triangular masking
        if self.input_in_fp16 and self.upper_triang_mask and \
           data_size[-1] <= 2048 and input.size()[2] == input.size()[3]:
            input = input.view(-1, data_size[2], data_size[3])
            scale = self.scale if self.scale is not None else 1.0
            probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale)
            probs = probs.view(*data_size)
        else:
            if self.input_in_fp16 and self.softmax_in_fp32:
                input = input.float()

            mask_output = self.mask_func(input, mask)
            if self.scale is not None:
                mask_output = mask_output * self.scale
            probs = torch.nn.Softmax(dim=-1)(mask_output)

            if self.input_in_fp16 and self.softmax_in_fp32:
                probs = probs.half()

        return probs
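A usage sketch (not part of the commit; it assumes a CUDA device, that the fused kernel has already been built via the --scaled-upper-triang-masked-softmax-fusion path, and uses a placeholder attention_mask_func):

    import torch
    from megatron.model.fused_softmax import FusedScaleMaskSoftmax

    def attention_mask_func(scores, mask):      # placeholder masking function
        return scores.masked_fill(mask, -10000.0)

    softmax = FusedScaleMaskSoftmax(
        input_in_fp16=True,
        upper_triang_mask=True,     # route to the fused CUDA kernel when possible
        mask_func=attention_mask_func,
        softmax_in_fp32=False,
        scale=None)

    # [b, np, s, s] attention scores; square fp16 scores with s <= 2048 hit the
    # fused path, on which the explicit mask argument is ignored.
    scores = torch.randn(4, 16, 128, 128, dtype=torch.half, device='cuda')
    mask = torch.ones(1, 1, 128, 128, device='cuda').triu(1).bool()
    probs = softmax(scores, mask)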
megatron/model/language_model.py

@@ -22,7 +22,6 @@ from megatron import get_args
 from megatron import mpu
 from megatron.module import MegatronModule
 from megatron.model.transformer import ParallelTransformer
-from megatron.model.utils import openai_gelu, erf_gelu
 from megatron.model.utils import get_linear_layer
 from megatron.model.utils import init_method_normal, scaled_init_method_normal
@@ -48,13 +47,6 @@ def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
     """Build language model and return along with the key to save."""
     args = get_args()
 
-    # Use torch gelu unless otherwise forced.
-    gelu = F.gelu
-    if args.openai_gelu:
-        gelu = openai_gelu
-    elif args.onnx_safe:
-        gelu = erf_gelu
-
     if init_method is None:
         init_method = init_method_normal(args.init_method_std)
@@ -64,7 +56,6 @@ def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
     # Language model.
     language_model = TransformerLanguageModel(
         attention_mask_func=attention_mask_func,
-        mlp_activation_func=gelu,
         init_method=init_method,
         output_layer_init_method=scaled_init_method,
         num_tokentypes=num_tokentypes,
@@ -271,7 +262,6 @@ class TransformerLanguageModel(MegatronModule):
     def __init__(self,
                  attention_mask_func,
-                 mlp_activation_func,
                  init_method,
                  output_layer_init_method,
                  num_tokentypes=0,
@@ -295,8 +285,8 @@ class TransformerLanguageModel(MegatronModule):
         # Transformer
         self.transformer = ParallelTransformer(
-            attention_mask_func, mlp_activation_func,
-            self.init_method, output_layer_init_method)
+            attention_mask_func, self.init_method,
+            output_layer_init_method)
         self._transformer_key = 'transformer'
 
         # Pooler
megatron/model/transformer.py

@@ -17,12 +17,21 @@
 import math
 
 import torch
+import torch.nn.functional as F
 
 from megatron import get_args
 from megatron import mpu
 from megatron.mpu import LayerNorm
 from megatron.module import MegatronModule
+from megatron.model.fused_softmax import FusedScaleMaskSoftmax
+from megatron.model.fused_bias_gelu import bias_gelu_impl
+from megatron.model.utils import openai_gelu, erf_gelu
+
+# flags required to enable jit fusion kernels
+torch._C._jit_set_profiling_mode(False)
+torch._C._jit_set_profiling_executor(False)
+torch._C._jit_override_can_fuse_on_cpu(True)
+torch._C._jit_override_can_fuse_on_gpu(True)
 
 """ We use the following notation throughout this file:
      h: hidden size
@@ -34,7 +43,7 @@ from megatron.module import MegatronModule
      b: batch size
      s: sequence length
      l: number of layers
-    Transformer takes input of size [b, s, h] and returns a
+    Transformer takes input of size [s, b, h] and returns a
     tensor of the same size. We use the following arguments:
         hyperparameters: transformer hyperparameters
         attention_mask_func: a function that takes `unmaksed-attention-scores`
@@ -45,7 +54,6 @@ from megatron.module import MegatronModule
             unmaksed-attention-scores, attention-mask)
 """
 
 class ParallelMLP(MegatronModule):
     """MLP.
@@ -55,8 +63,7 @@ class ParallelMLP(MegatronModule):
     applied.
     """
 
-    def __init__(self, mlp_activation_func, init_method,
-                 output_layer_init_method):
+    def __init__(self, init_method, output_layer_init_method):
         super(ParallelMLP, self).__init__()
         args = get_args()
@@ -65,29 +72,40 @@ class ParallelMLP(MegatronModule):
             args.hidden_size,
             4 * args.hidden_size,
             gather_output=False,
-            init_method=init_method)
+            init_method=init_method,
+            skip_bias_add=True)
 
-        self.activation_func = mlp_activation_func
+        self.bias_gelu_fusion = args.bias_gelu_fusion
+        self.activation_func = F.gelu
+        if args.openai_gelu:
+            self.activation_func = openai_gelu
+        elif args.onnx_safe:
+            self.activation_func = erf_gelu
 
         # Project back to h.
         self.dense_4h_to_h = mpu.RowParallelLinear(
             4 * args.hidden_size,
             args.hidden_size,
             input_is_parallel=True,
-            init_method=output_layer_init_method)
-        self.dropout = torch.nn.Dropout(args.hidden_dropout)
+            init_method=output_layer_init_method,
+            skip_bias_add=True)
 
     def forward(self, hidden_states):
 
-        # [b, s, 4hp]
-        intermediate_parallel = self.dense_h_to_4h(hidden_states)
-        intermediate_parallel = self.activation_func(intermediate_parallel)
+        # [s, b, 4hp]
+        intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)
 
-        # [b, s, h]
-        output = self.dense_4h_to_h(intermediate_parallel)
-        output = self.dropout(output)
-        return output
+        if self.bias_gelu_fusion:
+            intermediate_parallel = \
+                bias_gelu_impl(intermediate_parallel, bias_parallel)
+        else:
+            intermediate_parallel = \
+                self.activation_func(intermediate_parallel + bias_parallel)
 
+        # [s, b, h]
+        output, output_bias = self.dense_4h_to_h(intermediate_parallel)
+        return output, output_bias
 
 
 class ParallelSelfAttention(MegatronModule):
@@ -123,10 +141,22 @@ class ParallelSelfAttention(MegatronModule):
         self.query_key_value = mpu.ColumnParallelLinear(
             args.hidden_size,
             3 * args.hidden_size,
             stride=3,
             gather_output=False,
             init_method=init_method)
 
+        coeff = None
+        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
+        if self.apply_query_key_layer_scaling:
+            coeff = self.layer_number
+            self.norm_factor *= coeff
+
+        self.scale_mask_softmax = FusedScaleMaskSoftmax(
+            self.fp16,
+            args.scaled_upper_triang_masked_softmax_fusion,
+            self.attention_mask_func,
+            self.attention_softmax_in_fp32,
+            coeff)
+
         # Dropout. Note that for a single iteration, this layer will generate
         # different outputs on different number of parallel partitions but
         # on average it should not be partition dependent.
@@ -137,110 +167,85 @@ class ParallelSelfAttention(MegatronModule):
             args.hidden_size,
             args.hidden_size,
             input_is_parallel=True,
-            init_method=output_layer_init_method)
-        self.output_dropout = torch.nn.Dropout(args.hidden_dropout)
-
-    def _transpose_for_scores(self, tensor):
-        """Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with
-        size [b, np, s, hn].
-        """
-        new_tensor_shape = tensor.size()[:-1] + \
-                           (self.num_attention_heads_per_partition,
-                            self.hidden_size_per_attention_head)
-        tensor = tensor.view(*new_tensor_shape)
-        return tensor.permute(0, 2, 1, 3)
-
-    def _get_query_key_value(self, hidden_states):
-        """Get query, key, and value and transpose to
-        get size [b, np, s, hn].
-        """
-        # Attention heads. [b, s, hp]
-        mixed_x_layer = self.query_key_value(hidden_states)
-        (mixed_query_layer,
-         mixed_key_layer,
-         mixed_value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)
-        # Reshape and transpose [b, np, s, hn]
-        query_layer = self._transpose_for_scores(mixed_query_layer)
-        key_layer = self._transpose_for_scores(mixed_key_layer)
-        value_layer = self._transpose_for_scores(mixed_value_layer)
-        return query_layer, key_layer, value_layer
-
-    def _get_unmasked_attention_scores(self, query_layer, key_layer):
-        """Unmasked attention scores with size [b, np, s, s]."""
-        coeff = 1
-        if self.apply_query_key_layer_scaling:
-            coeff = self.layer_number
-        norm_factor = math.sqrt(coeff * math.sqrt(
-            self.hidden_size_per_attention_head))
-        # Raw attention scores. [b, np, s, s]
-        return torch.matmul(query_layer / norm_factor,
-                            key_layer.transpose(-1, -2) / norm_factor)
-
-    def _get_attention_probs(self, attention_scores):
-        """Attention probabilies with dropout. The output has
-        the size [b, np, s, s].
-        """
-        # Attention probabilities. [b, np, s, s]
-        if self.apply_query_key_layer_scaling:
-            attention_scores = attention_scores * self.layer_number
-        attention_probs = torch.nn.Softmax(dim=-1)(attention_scores)
-        # This is actually dropping out entire tokens to attend to, which might
-        # seem a bit unusual, but is taken from the original Transformer paper.
-        with mpu.get_cuda_rng_tracker().fork():
-            attention_probs = self.attention_dropout(attention_probs)
-        return attention_probs
-
-    def _get_attended_context(self, attention_probs, value_layer):
-        """Final attended tesnor and transposed back to [b, s, hp]."""
-        # Context layer.
-        # [b, np, s, hn]
-        context_layer = torch.matmul(attention_probs, value_layer)
-        # [b, s, np, hn]
-        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
-        new_context_layer_shape = context_layer.size()[:-2] + \
-            (self.hidden_size_per_partition,)
-        # [b, s, hp]
-        context_layer = context_layer.view(*new_context_layer_shape)
-        return context_layer
-
-    def _get_output(self, context_layer):
-        """Output layer with dropout."""
-        # Output. [b, s, h]
-        output = self.dense(context_layer)
-        output = self.output_dropout(output)
-        return output
-
-    def forward(self, hidden_states, attention_mask, layer_past=None,
-                get_key_value=False):
-        # hidden_states: [b, s, h]
-        # Attention heads. [b, np, s, hn]
-        query_layer, key_layer, value_layer = \
-            self._get_query_key_value(hidden_states)
+            init_method=output_layer_init_method,
+            skip_bias_add=True)
+
+    def forward(self, hidden_states, attention_mask, layer_past=None,
+                get_key_value=False):
+        # hidden_states: [s, b, h]
+
+        # =====================
+        # Query, Key, and Value
+        # =====================
+
+        # Attention heads [s, b, hp] --> [s, b, 3 * hp]
+        mixed_x_layer, _ = self.query_key_value(hidden_states)
+
+        # [s, b, 3 * hp] --> [s, b, np, 3 * hn]
+        new_tensor_shape = mixed_x_layer.size()[:-1] + \
+            (self.num_attention_heads_per_partition,
+             3 * self.hidden_size_per_attention_head)
+        mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
+
+        # [s, b, np, 3 * hn] --> 3 [s, b, np, hn]
+        (query_layer,
+         key_layer,
+         value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)
+
+        # ==================================
+        # Adjust key and value for inference
+        # ==================================
 
         if layer_past is not None:
             past_key, past_value = layer_past
             key_layer = torch.cat((past_key.type_as(key_layer),
-                                   key_layer), dim=-2)
+                                   key_layer), dim=0)
             value_layer = torch.cat((past_value.type_as(value_layer),
-                                     value_layer), dim=-2)
+                                     value_layer), dim=0)
         if get_key_value:
             present = (key_layer, value_layer)
 
-        # Raw attention scores. [b, np, s, s]
-        attention_scores = self._get_unmasked_attention_scores(query_layer,
-                                                               key_layer)
-        # fp32 conversion.
-        if self.fp16 and self.attention_softmax_in_fp32:
-            attention_scores = attention_scores.float()
+        # ===================================
+        # Raw attention scores. [b, np, s, s]
+        # ===================================
+
+        # [b, np, s, s]
+        output_size = (query_layer.size(1),
+                       query_layer.size(2),
+                       query_layer.size(0),
+                       key_layer.size(0))
+
+        # [s, b, np, hn] -> [s, b * np, hn]
+        query_layer = query_layer.view(output_size[2],
+                                       output_size[0] * output_size[1], -1)
+        key_layer = key_layer.view(output_size[3],
+                                   output_size[0] * output_size[1], -1)
+
+        # preallocting result tensor: [b * np, s, s]
+        matmul_result = torch.empty(
+            output_size[0] * output_size[1],
+            output_size[2],
+            output_size[3],
+            dtype=query_layer.dtype,
+            device=torch.cuda.current_device())
+
+        # Raw attention scores. [b * np, s, s]
+        matmul_result = torch.baddbmm(matmul_result,
+            query_layer.transpose(0, 1),                # [b * np, s, hn]
+            key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, s]
+            beta=0.0, alpha=(1.0 / self.norm_factor))
+
+        # change view to [b, np, s, s]
+        attention_scores = matmul_result.view(*output_size)
+
+        # ==================================================
+        # Update attention mask for inference. [b, np, s, s]
+        # ==================================================
 
-        # Apply attention mask. [b, np, s, s]
         if get_key_value:
             with torch.no_grad():
                 if layer_past is not None:
@@ -253,26 +258,93 @@ class ParallelSelfAttention(MegatronModule):
                         ...,
                         :attention_scores.size(3),
                         :attention_scores.size(3)]
 
         attention_scores = self.attention_mask_func(attention_scores,
                                                     attention_mask)
-        # Attention probabilities. [b, np, s, s]
-        attention_probs = self._get_attention_probs(attention_scores)
-        # fp16 conversion
-        if self.fp16 and self.attention_softmax_in_fp32:
-            attention_probs = attention_probs.half()
+        # ===========================
+        # Attention probs and dropout
+        # ===========================
 
-        # Context layer. [b, s, hp]
-        context_layer = self._get_attended_context(attention_probs, value_layer)
+        # attention scores and attention mask [b, np, s, s]
+        attention_probs = self.scale_mask_softmax(attention_scores,
+                                                  attention_mask)
 
-        # Output. [b, s, h]
-        output = self._get_output(context_layer)
+        # This is actually dropping out entire tokens to attend to, which might
+        # seem a bit unusual, but is taken from the original Transformer paper.
+        with mpu.get_cuda_rng_tracker().fork():
+            attention_probs = self.attention_dropout(attention_probs)
+
+        # =========================
+        # Context layer. [s, b, hp]
+        # =========================
+
+        # value_layer -> context layer.
+        # [s, b, np, hn] --> [b, np, s, hn]
+
+        # context layer shape: [b, np, s, hn]
+        output_size = (value_layer.size(1),
+                       value_layer.size(2),
+                       value_layer.size(0),
+                       value_layer.size(3))
+
+        # change view [s, b * np, hn]
+        value_layer = value_layer.view(output_size[2],
+                                       output_size[0] * output_size[1], -1)
+
+        # change view [b * np, s, s]
+        attention_probs = attention_probs.view(output_size[0] * output_size[1],
+                                               output_size[2], -1)
+
+        # matmul: [b * np, s, hn]
+        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
+
+        # change view [b, np, s, hn]
+        context_layer = context_layer.view(*output_size)
+
+        # [b, np, s, hn] --> [s, b, np, hn]
+        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
+
+        # [s, b, np, hn] --> [s, b, hp]
+        new_context_layer_shape = context_layer.size()[:-2] + \
+            (self.hidden_size_per_partition,)
+        context_layer = context_layer.view(*new_context_layer_shape)
+
+        # =================
+        # Output. [s, b, h]
+        # =================
+
+        output, bias = self.dense(context_layer)
 
         if get_key_value:
             output = [output, present]
 
-        return output
+        return output, bias
+
+
+def bias_dropout_add(x, bias, residual, prob, training) :
+    # type: (Tensor, Tensor, Tensor, float, bool) -> Tensor
+    out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
+    out = residual + out
+    return out
+
+
+def get_bias_dropout_add(training):
+    def _bias_dropout_add(x, bias, residual, prob):
+        return bias_dropout_add(x, bias, residual, prob, training)
+    return _bias_dropout_add
+
+
+@torch.jit.script
+def bias_dropout_add_fused_train(x, bias, residual, prob) :
+    # type: (Tensor, Tensor, Tensor, float) -> Tensor
+    return bias_dropout_add(x, bias, residual, prob, True)
+
+
+@torch.jit.script
+def bias_dropout_add_fused_inference(x, bias, residual, prob) :
+    # type: (Tensor, Tensor, Tensor, float) -> Tensor
+    return bias_dropout_add(x, bias, residual, prob, False)
 
 
 class ParallelTransformerLayer(MegatronModule):
@@ -282,8 +354,8 @@ class ParallelTransformerLayer(MegatronModule):
     output of the same size.
     """
 
-    def __init__(self, attention_mask_func, mlp_activation_func,
-                 init_method, output_layer_init_method, layer_number):
+    def __init__(self, attention_mask_func, init_method,
+                 output_layer_init_method, layer_number):
         args = get_args()
 
         super(ParallelTransformerLayer, self).__init__()
@@ -301,6 +373,8 @@ class ParallelTransformerLayer(MegatronModule):
         self.attention = ParallelSelfAttention(attention_mask_func,
                                                init_method,
                                                output_layer_init_method,
                                                layer_number)
+        self.hidden_dropout = args.hidden_dropout
+        self.bias_dropout_fusion = args.bias_dropout_fusion
 
         # Layernorm on the input data.
         self.post_attention_layernorm = LayerNorm(
@@ -308,7 +382,7 @@ class ParallelTransformerLayer(MegatronModule):
             eps=args.layernorm_epsilon)
 
         # MLP
-        self.mlp = ParallelMLP(mlp_activation_func, init_method,
+        self.mlp = ParallelMLP(init_method,
                                output_layer_init_method)
 
     def forward(self, hidden_states, attention_mask, layer_past=None,
@@ -318,28 +392,60 @@ class ParallelTransformerLayer(MegatronModule):
         # Layer norm at the begining of the transformer layer.
         layernorm_output = self.input_layernorm(hidden_states)
 
         # Self attention.
-        attention_output = self.attention(layernorm_output,
-                                          attention_mask,
-                                          layer_past=layer_past,
-                                          get_key_value=get_key_value)
+        attention_output, attention_bias = \
+            self.attention(layernorm_output,
+                           attention_mask,
+                           layer_past=layer_past,
+                           get_key_value=get_key_value)
 
         if get_key_value:
             attention_output, presents = attention_output
 
         # Residual connection.
         if self.apply_residual_connection_post_layernorm:
-            layernorm_input = layernorm_output + attention_output
+            residual = layernorm_output
         else:
-            layernorm_input = hidden_states + attention_output
+            residual = hidden_states
+
+        # jit scripting for a nn.module (with dropout) is not
+        # trigerring the fusion kernel. For now, we use two
+        # different nn.functional routines to account for varying
+        # dropout semantics during training and inference phases.
+        if self.bias_dropout_fusion:
+            if self.training:
+                bias_dropout_add_func = bias_dropout_add_fused_train
+            else:
+                bias_dropout_add_func = bias_dropout_add_fused_inference
+        else:
+            bias_dropout_add_func = get_bias_dropout_add(self.training)
+
+        #re-enable torch grad to enable fused optimization.
+        with torch.enable_grad():
+            layernorm_input = bias_dropout_add_func(
+                attention_output,
+                attention_bias.expand_as(residual),
+                residual,
+                self.hidden_dropout)
 
         # Layer norm post the self attention.
         layernorm_output = self.post_attention_layernorm(layernorm_input)
 
         # MLP.
-        mlp_output = self.mlp(layernorm_output)
+        mlp_output, mlp_bias = self.mlp(layernorm_output)
 
         # Second residual connection.
         if self.apply_residual_connection_post_layernorm:
-            output = layernorm_output + mlp_output
+            residual = layernorm_output
         else:
-            output = layernorm_input + mlp_output
+            residual = layernorm_input
+
+        #re-enable torch grad to enable fused optimization.
+        with torch.enable_grad():
+            output = bias_dropout_add_func(
+                mlp_output,
+                mlp_bias.expand_as(residual),
+                residual,
+                self.hidden_dropout)
 
         if get_key_value:
             output = [output, presents]
@@ -350,7 +456,7 @@ class ParallelTransformerLayer(MegatronModule):
 class ParallelTransformer(MegatronModule):
     """Transformer class."""
 
-    def __init__(self, attention_mask_func, mlp_activation_func,
+    def __init__(self, attention_mask_func,
                  init_method, output_layer_init_method):
         super(ParallelTransformer, self).__init__()
         args = get_args()
@@ -371,8 +477,8 @@ class ParallelTransformer(MegatronModule):
         # Transformer layers.
         def build_layer(layer_number):
             return ParallelTransformerLayer(
-                attention_mask_func, mlp_activation_func,
-                init_method, output_layer_init_method, layer_number)
+                attention_mask_func, init_method,
+                output_layer_init_method, layer_number)
 
         self.layers = torch.nn.ModuleList(
             [build_layer(i + 1) for i in range(self.num_unique_layers)])
@@ -435,6 +541,9 @@ class ParallelTransformer(MegatronModule):
             'get_key_value does not work with ' \
             'activation checkpointing'
 
+        # data format change to avoid explicit tranposes : [b s h] --> [s b h]
+        hidden_states = hidden_states.transpose(0, 1).contiguous()
+
         if self.checkpoint_activations:
             hidden_states = self._checkpointed_forward(hidden_states,
                                                        attention_mask)
@@ -453,6 +562,9 @@ class ParallelTransformer(MegatronModule):
                 if get_key_value:
                     hidden_states, present = hidden_states
                     presents.append(present)
 
+        # reverting data format change [s b h] --> [b s h]
+        hidden_states = hidden_states.transpose(0, 1).contiguous()
+
         # Final layer norm.
         output = self.final_layernorm(hidden_states)
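A quick equivalence sketch (not part of the commit): in eval mode, where dropout is the identity, the fused bias-dropout-add path and the plain unfused closure agree; both compute residual + dropout(x + bias). Shapes below are arbitrary stand-ins for [s, b, h] activations.

    import torch
    from megatron.model.transformer import (bias_dropout_add_fused_inference,
                                            get_bias_dropout_add)

    x = torch.randn(128, 2, 1024)           # e.g. attention output, [s, b, h]
    bias = torch.randn(1024).expand_as(x)   # bias broadcast to the residual shape
    residual = torch.randn(128, 2, 1024)

    fused = bias_dropout_add_fused_inference(x, bias, residual, 0.1)
    unfused = get_bias_dropout_add(training=False)(x, bias, residual, 0.1)
    assert torch.allclose(fused, unfused)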
megatron/mpu/layers.py

@@ -54,7 +54,7 @@ def _initialize_affine_weight_gpu(weight, init_method,
     weight.model_parallel = True
     weight.partition_dim = partition_dim
     weight.partition_stride = stride
 
     with get_cuda_rng_tracker().fork():
         init_method(weight)
@@ -186,11 +186,15 @@ class ColumnParallelLinear(torch.nn.Module):
         keep_master_weight_for_test: This was added for testing and should be
                                      set to False. It returns the master weights
                                      used for initialization.
+        skip_bias_add: This was added to enable performance optimations where bias
+                       can be fused with other elementwise operations. we skip
+                       adding bias but instead return it.
     """
     def __init__(self, input_size, output_size, bias=True, gather_output=True,
                  init_method=init.xavier_normal_, stride=1,
-                 keep_master_weight_for_test=False):
+                 keep_master_weight_for_test=False, skip_bias_add=False):
         super(ColumnParallelLinear, self).__init__()
 
         # Keep input parameters
@@ -200,6 +204,7 @@ class ColumnParallelLinear(torch.nn.Module):
         # Divide the weight matrix along the last dimension.
         world_size = get_model_parallel_world_size()
         self.output_size_per_partition = divide(output_size, world_size)
+        self.skip_bias_add = skip_bias_add
 
         # Parameters.
         # Note: torch.nn.functional.linear performs XA^T + b and as a result
@@ -245,13 +250,16 @@ class ColumnParallelLinear(torch.nn.Module):
         # Set up backprop all-reduce.
         input_parallel = copy_to_model_parallel_region(input_)
         # Matrix multiply.
-        output_parallel = F.linear(input_parallel, self.weight, self.bias)
+
+        bias = self.bias if not self.skip_bias_add else None
+        output_parallel = F.linear(input_parallel, self.weight, bias)
         if self.gather_output:
             # All-gather across the partitions.
             output = gather_from_model_parallel_region(output_parallel)
         else:
             output = output_parallel
-        return output
+        output_bias = self.bias if self.skip_bias_add else None
+        return output, output_bias
 
 
 class RowParallelLinear(torch.nn.Module):
@@ -279,12 +287,16 @@ class RowParallelLinear(torch.nn.Module):
         keep_master_weight_for_test: This was added for testing and should be
                                      set to False. It returns the master weights
                                      used for initialization.
+        skip_bias_add: This was added to enable performance optimations where bias
+                       can be fused with other elementwise operations. we skip
+                       adding bias but instead return it.
     """
     def __init__(self, input_size, output_size, bias=True,
                  input_is_parallel=False,
                  init_method=init.xavier_normal_, stride=1,
-                 keep_master_weight_for_test=False):
+                 keep_master_weight_for_test=False, skip_bias_add=False):
         super(RowParallelLinear, self).__init__()
 
         # Keep input parameters
@@ -294,6 +306,7 @@ class RowParallelLinear(torch.nn.Module):
         # Divide the weight matrix along the last dimension.
         world_size = get_model_parallel_world_size()
         self.input_size_per_partition = divide(input_size, world_size)
+        self.skip_bias_add = skip_bias_add
 
         # Parameters.
         # Note: torch.nn.functional.linear performs XA^T + b and as a result
@@ -340,8 +353,11 @@ class RowParallelLinear(torch.nn.Module):
         output_parallel = F.linear(input_parallel, self.weight)
         # All-reduce across all the partitions.
         output_ = reduce_from_model_parallel_region(output_parallel)
-        if self.bias is not None:
-            output = output_ + self.bias
+        if not self.skip_bias_add:
+            output = output_ + self.bias if self.bias is not None else output_
+            output_bias = None
         else:
             output = output_
-        return output
+            output_bias = self.bias
+        return output, output_bias
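A minimal sketch (plain torch, no model parallelism; tanh stands in for the activation) of the contract that skip_bias_add enables: deferring the bias add and folding it into the next elementwise op gives the same result as a standard linear followed by the activation.

    import torch
    import torch.nn.functional as F

    weight = torch.randn(4096, 1024)
    bias = torch.randn(4096)
    x = torch.randn(8, 1024)

    # standard path: bias added inside the linear
    y_standard = torch.tanh(F.linear(x, weight, bias))

    # skip_bias_add path: the layer returns (output, bias); the caller fuses
    # the bias add with the elementwise activation.
    y_nobias = F.linear(x, weight)
    y_fused = torch.tanh(y_nobias + bias)

    assert torch.allclose(y_standard, y_fused, atol=1e-5)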
megatron/training.py

@@ -236,29 +236,35 @@ def backward_step(optimizer, model, loss):
     timers = get_timers()
 
     # Backward pass.
+    timers('backward-backward').start()
     optimizer.zero_grad(set_grads_to_None=True)
     if args.fp16:
         optimizer.backward(loss, update_master_grads=False)
     else:
         loss.backward()
+    timers('backward-backward').stop()
 
     # All-reduce if needed.
     if args.DDP_impl == 'local':
-        timers('allreduce').start()
+        timers('backward-allreduce').start()
         model.allreduce_params(reduce_after=False,
                                fp32_allreduce=args.fp32_allreduce)
-        timers('allreduce').stop()
+        timers('backward-allreduce').stop()
 
     # Update master gradients.
+    timers('backward-master-grad').start()
     if args.fp16:
         optimizer.update_master_grads()
+    timers('backward-master-grad').stop()
 
     # Clipping gradients helps prevent the exploding gradient.
+    timers('backward-clip-grad').start()
     if args.clip_grad > 0:
         if not args.fp16:
             mpu.clip_grad_norm(model.parameters(), args.clip_grad)
         else:
             optimizer.clip_master_grads(args.clip_grad)
+    timers('backward-clip-grad').stop()
 
 
 def train_step(forward_step_func, data_iterator,
@@ -311,7 +317,10 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
             timers_to_log.append(name)
 
     add_to_logging('forward')
     add_to_logging('backward')
-    add_to_logging('allreduce')
+    add_to_logging('backward-backward')
+    add_to_logging('backward-allreduce')
+    add_to_logging('backward-master-grad')
+    add_to_logging('backward-clip-grad')
     add_to_logging('optimizer')
     add_to_logging('batch generator')
megatron/utils.py

@@ -46,9 +46,9 @@ def report_memory(name):
         torch.cuda.memory_allocated() / mega_bytes)
     string += ' | max allocated: {}'.format(
         torch.cuda.max_memory_allocated() / mega_bytes)
-    string += ' | cached: {}'.format(torch.cuda.memory_cached() / mega_bytes)
-    string += ' | max cached: {}'.format(
-        torch.cuda.max_memory_cached() / mega_bytes)
+    string += ' | reserved: {}'.format(torch.cuda.memory_reserved() / mega_bytes)
+    string += ' | max reserved: {}'.format(
+        torch.cuda.max_memory_reserved() / mega_bytes)
     print_rank_0(string)
pretrain_gpt2.py

@@ -28,7 +28,6 @@ from megatron.training import pretrain
 from megatron.utils import get_ltor_masks_and_position_ids
 from megatron.utils import reduce_losses
 
 
 def model_provider():
     """Build the model."""
@@ -80,10 +79,8 @@ def forward_step(data_iterator, model):
     tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
         data_iterator)
     timers('batch generator').stop()
 
     # Forward model.
     losses = model(tokens, position_ids, attention_mask, labels=labels)
     loss_mask = loss_mask.view(-1)
     loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()