OpenDAS / Megatron-LM / Commits

Commit 51a2e6b0, authored Sep 11, 2020 by Vijay Korthikanti, committed by Jared Casper, Sep 11, 2020

    Various speed optimizations.

Parent: 12518332

Showing 13 changed files with 1103 additions and 163 deletions (+1103 / -163).
megatron/arguments.py                                                +13   -1
megatron/fused_kernels/__init__.py                                   +53   -0
megatron/fused_kernels/scaled_upper_triang_masked_softmax.cpp        +69   -0
megatron/fused_kernels/scaled_upper_triang_masked_softmax.h          +439  -0
megatron/fused_kernels/scaled_upper_triang_masked_softmax_cuda.cu    +89   -0
megatron/model/fused_bias_gelu.py                                    +60   -0
megatron/model/fused_softmax.py                                      +94   -0
megatron/model/language_model.py                                     +2    -12
megatron/model/transformer.py                                        +244  -132
megatron/mpu/layers.py                                               +25   -9
megatron/training.py                                                 +12   -3
megatron/utils.py                                                    +3    -3
pretrain_gpt2.py                                                     +0    -3
megatron/arguments.py  (+13 / -1)

@@ -19,7 +19,7 @@ import argparse
 import os

 import torch
+from megatron import fused_kernels

 def parse_args(extra_args_provider=None, defaults={},
                ignore_unknown_args=False):
@@ -118,6 +118,10 @@ def parse_args(extra_args_provider=None, defaults={},
         'for distribute-checkpointed-activations to work you '\
         'need to enable checkpoint-activations'

+    # load scaled_upper_triang_masked_softmax_fusion kernel
+    if args.scaled_upper_triang_masked_softmax_fusion:
+        fused_kernels.load_scaled_upper_triang_masked_softmax_fusion_kernel()
+
     _print_args(args)
     return args
@@ -221,6 +225,14 @@ def _add_training_args(parser):
                        'by this value.')
     group.add_argument('--tensorboard-dir', type=str, default=None,
                        help='Write TensorBoard logs to this directory.')
+    group.add_argument('--scaled-upper-triang-masked-softmax-fusion',
+                       action='store_true',
+                       help='Enable fusion of query_key_value_scaling '
+                       'time (upper diagonal) masking, softmax.')
+    group.add_argument('--bias-gelu-fusion', action='store_true',
+                       help='Enable bias and gelu fusion.')
+    group.add_argument('--bias-dropout-fusion', action='store_true',
+                       help='Enable bias and dropout fusion.')

     return parser
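All three new switches are plain store_true flags, so the optimizations stay off unless requested at launch. A standalone sketch of the same three options on a bare parser (illustrative only, not part of the commit; Megatron's real parser carries many more required arguments):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--scaled-upper-triang-masked-softmax-fusion', action='store_true')
    parser.add_argument('--bias-gelu-fusion', action='store_true')
    parser.add_argument('--bias-dropout-fusion', action='store_true')

    args = parser.parse_args(['--scaled-upper-triang-masked-softmax-fusion',
                              '--bias-gelu-fusion', '--bias-dropout-fusion'])
    print(args.scaled_upper_triang_masked_softmax_fusion,
          args.bias_gelu_fusion, args.bias_dropout_fusion)  # True True True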
megatron/fused_kernels/__init__.py  (new file, +53 lines)

# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pathlib
import subprocess

from torch.utils import cpp_extension


def load_scaled_upper_triang_masked_softmax_fusion_kernel():

    def get_cuda_bare_metal_version(cuda_dir):
        raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"],
                                             universal_newlines=True)
        output = raw_output.split()
        release_idx = output.index("release") + 1
        release = output[release_idx].split(".")
        bare_metal_major = release[0]
        bare_metal_minor = release[1][0]

        return raw_output, bare_metal_major, bare_metal_minor

    # Check, if CUDA11 is installed for compute capability 8.0
    cc_flag = []
    _, bare_metal_major, _ = get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
    if int(bare_metal_major) >= 11:
        cc_flag.append('-gencode')
        cc_flag.append('arch=compute_80,code=sm_80')

    srcpath = pathlib.Path(__file__).parent.absolute()
    scaled_upper_triang_masked_softmax_cuda = cpp_extension.load(
        name='scaled_upper_triang_masked_softmax_cuda',
        sources=[srcpath / 'scaled_upper_triang_masked_softmax.cpp',
                 srcpath / 'scaled_upper_triang_masked_softmax_cuda.cu'],
        extra_cflags=['-O3',],
        extra_cuda_cflags=['-O3',
                           '-gencode', 'arch=compute_70,code=sm_70',
                           '-U__CUDA_NO_HALF_OPERATORS__',
                           '-U__CUDA_NO_HALF_CONVERSIONS__',
                           '--expt-relaxed-constexpr',
                           '--expt-extended-lambda',
                           '--use_fast_math'] + cc_flag,
        verbose=True)
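Because cpp_extension.load registers the built module under the name passed to it, later code can pick it up with a bare import in the same process; that is exactly how megatron/model/fused_softmax.py uses it below. A sketch of exercising the loader directly (assumes a CUDA toolkit and GPU so the JIT build succeeds; the shape follows the kernel's [attn_batches, seq_len, seq_len] convention, and attn_batches is chosen to satisfy the dispatcher's divisibility assert):

    import torch
    from megatron import fused_kernels

    # Builds the extension with nvcc on first use (cached afterwards).
    fused_kernels.load_scaled_upper_triang_masked_softmax_fusion_kernel()

    # The JIT loader registered the module, so a plain import now works.
    import scaled_upper_triang_masked_softmax_cuda

    x = torch.randn(8, 128, 128, dtype=torch.half, device='cuda')
    probs = scaled_upper_triang_masked_softmax_cuda.forward(x, 1.0)
    print(probs.shape)  # torch.Size([8, 128, 128])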
megatron/fused_kernels/scaled_upper_triang_masked_softmax.cpp  (new file, +69 lines)

/* coding=utf-8
 * Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <cuda_fp16.h>
#include <torch/extension.h>
#include <vector>

namespace multihead_attn {
namespace fused_softmax {
namespace scaled_upper_triang_masked_softmax {

torch::Tensor fwd_cuda(
    torch::Tensor const& input,
    float scale_factor);

torch::Tensor bwd_cuda(
    torch::Tensor const& output_grads,
    torch::Tensor const& softmax_results,
    float scale_factor);

torch::Tensor fwd(torch::Tensor const& input, float scale_factor) {
  AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(input.scalar_type() == at::ScalarType::Half,
             "Only HALF is supported");

  return fwd_cuda(input, scale_factor);
}

torch::Tensor bwd(
    torch::Tensor const& output_grads,
    torch::Tensor const& softmax_results,
    float scale_factor) {
  AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(output_grads.scalar_type() == at::ScalarType::Half,
             "Only HALF is supported");
  AT_ASSERTM(softmax_results.scalar_type() == at::ScalarType::Half,
             "Only HALF is supported");

  return bwd_cuda(output_grads, softmax_results, scale_factor);
}

} // end namespace scaled_upper_triang_masked_softmax
} // end namespace fused_softmax
} // end namespace multihead_attn

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward",
        &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd,
        "Self Multihead Attention scaled, time masked softmax -- Forward.");
  m.def("backward",
        &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd,
        "Self Multihead Attention scaled, time masked softmax -- Backward.");
}
megatron/fused_kernels/scaled_upper_triang_masked_softmax.h  (new file, +439 lines)

/* coding=utf-8
 * Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <assert.h>
#include <cuda_fp16.h>
#include <cfloat>
#include <limits>
#include <stdint.h>
#include <cuda_fp16.h>
#include <c10/macros/Macros.h>

namespace {

int log2_ceil(int value) {
    int log2_value = 0;
    while ((1 << log2_value) < value) ++log2_value;
    return log2_value;
}

template<typename T>
struct Add {
    __device__ __forceinline__ T operator()(T a, T b) const {
        return a + b;
    }
};

template<typename T>
struct Max {
    __device__ __forceinline__ T operator()(T a, T b) const {
        return a < b ? b : a;
    }
};

template <typename T>
__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask,
    int width = warpSize, unsigned int mask = 0xffffffff)
{
#if CUDA_VERSION >= 9000
    return __shfl_xor_sync(mask, value, laneMask, width);
#else
    return __shfl_xor(value, laneMask, width);
#endif
}

template <typename acc_t, int WARP_BATCH, int WARP_SIZE, template<typename> class ReduceOp>
__device__ __forceinline__ void warp_reduce(acc_t* sum) {
    ReduceOp<acc_t> r;
    #pragma unroll
    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
        #pragma unroll
        for (int i = 0; i < WARP_BATCH; ++i) {
            acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE);
            sum[i] = r(sum[i], b);
        }
    }
}

/*
 * Extended softmax (from native aten pytorch) with following additional features
 * 1) input scaling
 * 2) Implicit time (diagonal masking)
 */
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_upper_triang_masked_softmax_warp_forward(
    output_t *dst,
    const input_t *src,
    const acc_t scale,
    int batch_size,
    int stride,
    int element_count)
{
    // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
    // warp_size of method warp_softmax_forward_kernel.
    constexpr int next_power_of_two = 1 << log2_elements;
    constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
    constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
    constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;

    int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x;
    int local_seq = blockIdx.x + 1;
    int warp_iteration_limit = (local_seq + WARP_SIZE - 1) / WARP_SIZE;

    // batch_size might not be a multiple of WARP_BATCH. Check how
    // many batches have to computed within this WARP.
    int local_batches = batch_size - first_batch;
    if (local_batches > WARP_BATCH)
        local_batches = WARP_BATCH;

    // there might be multiple batches per warp. compute the index within the batch
    int local_idx = threadIdx.x;

    src += first_batch * stride + local_idx;
    dst += first_batch * stride + local_idx;

    // load data from global memory
    acc_t elements[WARP_BATCH][WARP_ITERATIONS];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        int batch_element_count = (i >= local_batches) ? 0 : local_seq;

        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; ++it) {
            int element_index = local_idx + it * WARP_SIZE;
            if (element_index < batch_element_count) {
                elements[i][it] = (acc_t)src[i*element_count*stride + it*WARP_SIZE] * scale;
            } else {
                elements[i][it] = -std::numeric_limits<acc_t>::infinity();
            }
        }
    }

    // compute max_value
    acc_t max_value[WARP_BATCH];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        max_value[i] = elements[i][0];
        #pragma unroll
        for (int it = 1; it < WARP_ITERATIONS; ++it) {
            max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
        }
    }
    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);

    acc_t sum[WARP_BATCH] { 0.0f };
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; ++it) {
            if (it < warp_iteration_limit) {
                elements[i][it] = std::exp((elements[i][it] - max_value[i]));
                sum[i] += elements[i][it];
            }
        }
    }
    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);

    // store result
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        if (i >= local_batches)
            break;
        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; ++it) {
            int element_index = local_idx + it * WARP_SIZE;
            if (element_index < local_seq) {
                dst[i*element_count*stride + it*WARP_SIZE] = (output_t)(elements[i][it] / sum[i]);
            } else if (element_index < element_count) {
                dst[i*element_count*stride + it*WARP_SIZE] = 0;
            } else {
                break;
            }
        }
    }
}

template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_upper_triang_masked_softmax_warp_backward(
    output_t *gradInput,
    input_t *grad,
    const input_t *output,
    acc_t scale,
    int batch_size,
    int stride,
    int element_count)
{
    // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
    // warp_size of method warp_softmax_backward_kernel.
    constexpr int next_power_of_two = 1 << log2_elements;
    constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
    constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
    constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;

    int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x;
    int local_seq = blockIdx.x + 1;

    // batch_size might not be a multiple of WARP_BATCH. Check how
    // many batches have to computed within this WARP.
    int local_batches = batch_size - first_batch;
    if (local_batches > WARP_BATCH)
        local_batches = WARP_BATCH;

    // there might be multiple batches per warp. compute the index within the batch
    int local_idx = threadIdx.x;

    // the first element to process by the current thread
    int thread_offset = first_batch * stride + local_idx;
    grad += thread_offset;
    output += thread_offset;
    gradInput += thread_offset;

    // load data from global memory
    acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
    acc_t output_reg[WARP_BATCH][WARP_ITERATIONS];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        int batch_element_count = (i >= local_batches) ? 0 : local_seq;

        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; ++it) {
            int element_index = local_idx + it * WARP_SIZE;
            if (element_index < batch_element_count) {
                output_reg[i][it] = output[i*element_count*stride + it*WARP_SIZE];
            } else {
                output_reg[i][it] = acc_t(0);
            }
        }

        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; ++it) {
            int element_index = local_idx + it * WARP_SIZE;
            if (element_index < batch_element_count) {
                grad_reg[i][it] = (acc_t)grad[i*element_count*stride + it*WARP_SIZE] * output_reg[i][it];
            } else {
                grad_reg[i][it] = acc_t(0);
            }
        }
    }

    acc_t sum[WARP_BATCH];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        sum[i] = grad_reg[i][0];
        #pragma unroll
        for (int it = 1; it < WARP_ITERATIONS; ++it) {
            sum[i] += grad_reg[i][it];
        }
    }
    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);

    // store result
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        if (i >= local_batches)
            break;
        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; ++it) {
            int element_index = local_idx + it * WARP_SIZE;
            if (element_index < element_count) {
                // compute gradients
                gradInput[i*element_count*stride + it*WARP_SIZE] =
                    (output_t)(scale * (grad_reg[i][it] - output_reg[i][it] * sum[i]));
            }
        }
    }
}

} // end of anonymous namespace

template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_upper_triang_masked_softmax_forward(
    output_t *dst,
    const input_t *src,
    const input_t scale,
    int softmax_elements,
    int softmax_elements_stride,
    int attn_batches)
{
    TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048);
    if (softmax_elements == 0) {
        return;
    } else {
        int log2_elements = log2_ceil(softmax_elements);
        const int next_power_of_two = 1 << log2_elements;
        int seq_len = softmax_elements;
        int batch_count = attn_batches * seq_len;

        // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
        int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;

        // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward.
        int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

        // use 128 threads per block to maximimize gpu utilization
        constexpr int threads_per_block = 128;

        int warps_per_block = (threads_per_block / warp_size);
        int batches_per_block = warps_per_block * batches_per_warp;
        TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);

        int blocks_per_seq = attn_batches / batches_per_block;
        dim3 blocks(seq_len, blocks_per_seq, 1);
        dim3 threads(warp_size, warps_per_block, 1);
        // Launch code would be more elegant if C++ supported FOR CONSTEXPR
        switch (log2_elements) {
            case 0: // 1
                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 0>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 1: // 2
                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 1>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 2: // 4
                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 2>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 3: // 8
                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 3>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 4: // 16
                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 4>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 5: // 32
                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 5>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 6: // 64
                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 6>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 7: // 128
                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 7>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 8: // 256
                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 8>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 9: // 512
                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 9>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 10: // 1024
                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 10>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 11: // 2048
                scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 11>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            default:
                break;
        }
    }
}

template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_upper_triang_masked_softmax_backward(
    output_t *grad_input,
    input_t *grad,
    const input_t *output,
    const acc_t scale,
    int softmax_elements,
    int softmax_elements_stride,
    int attn_batches)
{
    TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048);
    if (softmax_elements == 0) {
        return;
    } else {
        int log2_elements = log2_ceil(softmax_elements);
        const int next_power_of_two = 1 << log2_elements;
        int seq_len = softmax_elements;
        int batch_count = attn_batches * seq_len;

        // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.
        int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;

        // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward.
        int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

        // use 128 threads per block to maximimize gpu utilization
        constexpr int threads_per_block = 128;

        int warps_per_block = (threads_per_block / warp_size);
        int batches_per_block = warps_per_block * batches_per_warp;
        TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);

        int blocks_per_seq = attn_batches / batches_per_block;
        dim3 blocks(seq_len, blocks_per_seq, 1);
        dim3 threads(warp_size, warps_per_block, 1);
        // Launch code would be more elegant if C++ supported FOR CONSTEXPR
        switch (log2_elements) {
            case 0: // 1
                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 0>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 1: // 2
                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 1>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 2: // 4
                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 2>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 3: // 8
                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 3>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 4: // 16
                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 4>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 5: // 32
                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 5>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 6: // 64
                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 6>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 7: // 128
                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 7>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 8: // 256
                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 8>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 9: // 512
                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 9>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 10: // 1024
                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 10>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            case 11: // 2048
                scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 11>
                    <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements);
                break;
            default:
                break;
        }
    }
}
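The two dispatch functions derive the whole launch geometry from the softmax width, and the TORCH_INTERNAL_ASSERT above only accepts attn_batches values that divide evenly into batches_per_block. A small Python transcription of that arithmetic, handy for checking which shapes the dispatcher will accept (illustrative only; it mirrors the C++ above with C10_WARP_SIZE assumed to be 32):

    import math

    C10_WARP_SIZE = 32

    def launch_geometry(softmax_elements, attn_batches):
        log2_elements = math.ceil(math.log2(softmax_elements))
        next_power_of_two = 1 << log2_elements
        warp_size = min(next_power_of_two, C10_WARP_SIZE)
        batches_per_warp = 2 if next_power_of_two <= 128 else 1
        threads_per_block = 128
        warps_per_block = threads_per_block // warp_size
        batches_per_block = warps_per_block * batches_per_warp
        assert attn_batches % batches_per_block == 0, 'shape rejected by the dispatcher'
        blocks_per_seq = attn_batches // batches_per_block
        # dim3 blocks(seq_len, blocks_per_seq, 1); dim3 threads(warp_size, warps_per_block, 1)
        return (softmax_elements, blocks_per_seq, 1), (warp_size, warps_per_block, 1)

    # seq_len 1024 with 16 heads * micro-batch 6 = 96 attention batches:
    print(launch_geometry(1024, 96))  # ((1024, 24, 1), (32, 4, 1))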
megatron/fused_kernels/scaled_upper_triang_masked_softmax_cuda.cu  (new file, +89 lines)

/* coding=utf-8
 * Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include "THC/THC.h"
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "scaled_upper_triang_masked_softmax.h"

namespace multihead_attn {
namespace fused_softmax {
namespace scaled_upper_triang_masked_softmax {

torch::Tensor fwd_cuda(
    torch::Tensor const& input,
    float scale_factor)
{
  // input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]
  const int attn_batches = input.size(0);
  const int seq_len = input.size(1);
  TORCH_INTERNAL_ASSERT(seq_len <= 2048);

  // Output
  auto act_options = input.options().requires_grad(false);
  torch::Tensor softmax_results =
      torch::empty({attn_batches, seq_len, seq_len}, act_options);

  // Softmax Intermediate Result Ptr
  void* input_ptr = static_cast<void*>(input.data_ptr());
  void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());

  dispatch_scaled_upper_triang_masked_softmax_forward<half, half, float>(
      reinterpret_cast<half*>(softmax_results_ptr),
      reinterpret_cast<const half*>(input_ptr),
      scale_factor,
      seq_len,
      seq_len,
      attn_batches);

  return softmax_results;
}

torch::Tensor bwd_cuda(
    torch::Tensor const& output_grads_,
    torch::Tensor const& softmax_results_,
    float scale_factor)
{
  auto output_grads = output_grads_.contiguous();
  auto softmax_results = softmax_results_.contiguous();

  //output grads is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]
  const int attn_batches = output_grads.size(0);
  const int seq_len = output_grads.size(1);
  TORCH_INTERNAL_ASSERT(output_grads.size(1) == output_grads.size(2));

  void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());

  //Softmax Grad
  dispatch_scaled_upper_triang_masked_softmax_backward<half, half, float>(
      reinterpret_cast<half*>(output_grads_ptr),
      reinterpret_cast<half*>(output_grads_ptr),
      reinterpret_cast<half const*>(softmax_results.data_ptr()),
      scale_factor,
      seq_len,
      seq_len,
      attn_batches);

  //backward pass is completely in-place
  return output_grads;
}

}
}
}
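The forward kernel is documented above as scaling, implicit upper triangular (time) masking, then softmax, which makes it easy to cross-check against unfused PyTorch ops. A hedged correctness sketch (assumes the extension has already been JIT-built and loaded in this process, as in fused_kernels/__init__.py):

    import torch
    import scaled_upper_triang_masked_softmax_cuda

    attn_batches, seq_len, scale = 8, 128, 0.5
    x = torch.randn(attn_batches, seq_len, seq_len, dtype=torch.half, device='cuda')

    fused = scaled_upper_triang_masked_softmax_cuda.forward(x, scale)

    # Reference: scale, mask future positions (strict upper triangle), softmax in fp32.
    causal = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool, device='cuda'),
                        diagonal=1)
    ref = (x.float() * scale).masked_fill(causal, float('-inf')).softmax(dim=-1).half()

    print((fused - ref).abs().max())  # expected to be within fp16 rounding error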
megatron/model/fused_bias_gelu.py  (new file, +60 lines)

# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)

###### BIAS GELU FUSION/ NO AUTOGRAD ################
# 1/sqrt(2*pi)-> 0.3989423
# 1/sqrt(2)   -> 0.70710678
# sqrt(2/pi)  -> 0.79788456
# this function is tanh approximation of gelu
# actual gelu is:
# x * 0.5 * (1.0 + torch.erf(x * 0.70710678))

@torch.jit.script
def bias_gelu(bias, y):
    x = bias + y
    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))

# gradient of tanh approximation of gelu
# gradient of actual gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
@torch.jit.script
def bias_gelu_back(g, bias, y):
    x = bias + y
    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
    return ff * g

class GeLUFunction(torch.autograd.Function):
    @staticmethod
    # bias is an optional argument
    def forward(ctx, input, bias):
        ctx.save_for_backward(input, bias)
        return bias_gelu(bias, input)

    @staticmethod
    def backward(ctx, grad_output):
        input, bias = ctx.saved_tensors
        tmp = bias_gelu_back(grad_output, bias, input)
        return tmp, tmp

bias_gelu_impl = GeLUFunction.apply
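The comments above spell out the tanh approximation that the fused path computes, so the forward result can be checked against plain tensor ops. A minimal sketch (assumes the megatron package is importable; runs on CPU since the scripted functions are device-agnostic):

    import torch
    from megatron.model.fused_bias_gelu import bias_gelu_impl

    x = torch.randn(16, 64)
    b = torch.randn(64)

    fused = bias_gelu_impl(x, b)

    y = x + b
    ref = y * 0.5 * (1.0 + torch.tanh(0.79788456 * y * (1 + 0.044715 * y * y)))

    print(torch.allclose(fused, ref, atol=1e-6))  # True: same tanh-form gelu

    # Note: the custom backward returns the same tensor for both inputs, since
    # d(bias + y)/dy and d(bias + y)/dbias are both 1 elementwise.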
megatron/model/fused_softmax.py  (new file, +94 lines)

# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch


class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
    """
    Fused operation which performs following three operations in sequence
    1. Scale the tensor.
    2. Apply upper triangular mask (typically used in gpt models).
    3. Perform softmax.
    """

    @staticmethod
    def forward(ctx, inputs, scale):
        import scaled_upper_triang_masked_softmax_cuda
        scale_t = torch.tensor([scale])

        softmax_results = \
            scaled_upper_triang_masked_softmax_cuda.forward(inputs, scale_t[0])
        ctx.save_for_backward(softmax_results, scale_t)
        return softmax_results

    @staticmethod
    def backward(ctx, output_grads):
        import scaled_upper_triang_masked_softmax_cuda
        softmax_results, scale_t = ctx.saved_tensors

        input_grads = \
            scaled_upper_triang_masked_softmax_cuda.backward(output_grads,
                                                             softmax_results,
                                                             scale_t[0])
        return input_grads, None


class FusedScaleMaskSoftmax(torch.nn.Module):
    """
    fused operation: scaling + mask + softmax
    Arguments:
        input_in_fp16: flag to indicate if input in fp16 data format.
        upper_triang_mask: if true, apply upper triangular masking.
                           (used in gpt family networks)
        mask_func: mask function to be applied.
        softmax_in_fp32: if true, softmax in performed at fp32 precision.
        scale: scaling factor used in input tensor scaling.
    """

    def __init__(self, input_in_fp16, upper_triang_mask, mask_func,
                 softmax_in_fp32, scale):
        super(FusedScaleMaskSoftmax, self).__init__()
        self.input_in_fp16 = input_in_fp16
        self.upper_triang_mask = upper_triang_mask
        self.mask_func = mask_func
        self.softmax_in_fp32 = softmax_in_fp32
        self.scale = scale

        assert self.scale is None or softmax_in_fp32, \
            'softmax should be in fp32 when scaled'

    def forward(self, input, mask):
        # [b, np, s, s]
        data_size = input.size()
        assert input.dim() == 4

        # invoke custom kernel for implicit uuper triangular masking
        if self.input_in_fp16 and self.upper_triang_mask and \
                data_size[-1] <= 2048 and input.size()[2] == input.size()[3]:
            input = input.view(-1, data_size[2], data_size[3])
            scale = self.scale if self.scale is not None else 1.0
            probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale)
            probs = probs.view(*data_size)
        else:
            if self.input_in_fp16 and self.softmax_in_fp32:
                input = input.float()

            mask_output = self.mask_func(input, mask)
            if self.scale is not None:
                mask_output = mask_output * self.scale
            probs = torch.nn.Softmax(dim=-1)(mask_output)

            if self.input_in_fp16 and self.softmax_in_fp32:
                probs = probs.half()

        return probs
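FusedScaleMaskSoftmax falls back to the unfused mask_func + Softmax path whenever the kernel's preconditions (fp16 input, upper-triangular mode, square score matrix, seq_len <= 2048) are not met, so the module can also be exercised without a GPU. A usage sketch with a stand-in mask function (Megatron's own GPT-2 mask function likewise fills masked positions with a large negative value):

    import torch
    from megatron.model.fused_softmax import FusedScaleMaskSoftmax

    def attention_mask_func(attention_scores, attention_mask):
        attention_scores.masked_fill_(attention_mask, -10000.0)
        return attention_scores

    softmax = FusedScaleMaskSoftmax(
        input_in_fp16=False,          # fp32 input -> unfused fallback branch
        upper_triang_mask=True,
        mask_func=attention_mask_func,
        softmax_in_fp32=True,
        scale=None)

    b, np_, s = 2, 4, 16
    scores = torch.randn(b, np_, s, s)
    mask = torch.triu(torch.ones(1, 1, s, s, dtype=torch.bool), diagonal=1)
    probs = softmax(scores, mask)
    print(probs.shape, float(probs.sum(-1)[0, 0, 0]))  # rows still sum to 1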
megatron/model/language_model.py  (+2 / -12)

@@ -22,7 +22,6 @@ from megatron import get_args
 from megatron import mpu
 from megatron.module import MegatronModule
 from megatron.model.transformer import ParallelTransformer
-from megatron.model.utils import openai_gelu, erf_gelu
 from megatron.model.utils import get_linear_layer
 from megatron.model.utils import init_method_normal, scaled_init_method_normal
@@ -48,13 +47,6 @@ def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
     """Build language model and return along with the key to save."""
     args = get_args()

-    # Use torch gelu unless otherwise forced.
-    gelu = F.gelu
-    if args.openai_gelu:
-        gelu = openai_gelu
-    elif args.onnx_safe:
-        gelu = erf_gelu
-
     if init_method is None:
         init_method = init_method_normal(args.init_method_std)
@@ -64,7 +56,6 @@ def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
     # Language model.
     language_model = TransformerLanguageModel(
         attention_mask_func=attention_mask_func,
-        mlp_activation_func=gelu,
         init_method=init_method,
         output_layer_init_method=scaled_init_method,
         num_tokentypes=num_tokentypes,
@@ -271,7 +262,6 @@ class TransformerLanguageModel(MegatronModule):
     def __init__(self,
                  attention_mask_func,
-                 mlp_activation_func,
                  init_method,
                  output_layer_init_method,
                  num_tokentypes=0,
@@ -295,8 +285,8 @@ class TransformerLanguageModel(MegatronModule):
         # Transformer
         self.transformer = ParallelTransformer(
-            attention_mask_func, mlp_activation_func,
-            self.init_method, output_layer_init_method)
+            attention_mask_func, self.init_method,
+            output_layer_init_method)
         self._transformer_key = 'transformer'

         # Pooler
megatron/model/transformer.py  (+244 / -132)

@@ -17,12 +17,21 @@
 import math

 import torch
+import torch.nn.functional as F

 from megatron import get_args
 from megatron import mpu
 from megatron.mpu import LayerNorm
 from megatron.module import MegatronModule
+from megatron.model.fused_softmax import FusedScaleMaskSoftmax
+from megatron.model.fused_bias_gelu import bias_gelu_impl
+from megatron.model.utils import openai_gelu, erf_gelu
+
+# flags required to enable jit fusion kernels
+torch._C._jit_set_profiling_mode(False)
+torch._C._jit_set_profiling_executor(False)
+torch._C._jit_override_can_fuse_on_cpu(True)
+torch._C._jit_override_can_fuse_on_gpu(True)

 """ We use the following notation throughout this file:
     h: hidden size
@@ -34,7 +43,7 @@
     b: batch size
     s: sequence length
     l: number of layers
-    Transformer takes input of size [b, s, h] and returns a
+    Transformer takes input of size [s, b, h] and returns a
     tensor of the same size. We use the following arguments:
     hyperparameters: transformer hyperparameters
     attention_mask_func: a function that takes `unmaksed-attention-scores`
@@ -45,7 +54,6 @@
         unmaksed-attention-scores, attention-mask)
 """

 class ParallelMLP(MegatronModule):
     """MLP.
@@ -55,8 +63,7 @@
     applied.
     """

-    def __init__(self, mlp_activation_func, init_method,
-                 output_layer_init_method):
+    def __init__(self, init_method, output_layer_init_method):
         super(ParallelMLP, self).__init__()
         args = get_args()
@@ -65,29 +72,40 @@
             args.hidden_size,
             4 * args.hidden_size,
             gather_output=False,
-            init_method=init_method)
+            init_method=init_method,
+            skip_bias_add=True)

-        self.activation_func = mlp_activation_func
+        self.bias_gelu_fusion = args.bias_gelu_fusion
+        self.activation_func = F.gelu
+        if args.openai_gelu:
+            self.activation_func = openai_gelu
+        elif args.onnx_safe:
+            self.activation_func = erf_gelu

         # Project back to h.
         self.dense_4h_to_h = mpu.RowParallelLinear(
             4 * args.hidden_size,
             args.hidden_size,
             input_is_parallel=True,
-            init_method=output_layer_init_method)
-
-        self.dropout = torch.nn.Dropout(args.hidden_dropout)
+            init_method=output_layer_init_method,
+            skip_bias_add=True)

     def forward(self, hidden_states):

-        # [b, s, 4hp]
-        intermediate_parallel = self.dense_h_to_4h(hidden_states)
-        intermediate_parallel = self.activation_func(intermediate_parallel)
+        # [s, b, 4hp]
+        intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)

-        # [b, s, h]
-        output = self.dense_4h_to_h(intermediate_parallel)
-        output = self.dropout(output)
-        return output
+        if self.bias_gelu_fusion:
+            intermediate_parallel = \
+                bias_gelu_impl(intermediate_parallel, bias_parallel)
+        else:
+            intermediate_parallel = \
+                self.activation_func(intermediate_parallel + bias_parallel)
+
+        # [s, b, h]
+        output, output_bias = self.dense_4h_to_h(intermediate_parallel)
+        return output, output_bias


 class ParallelSelfAttention(MegatronModule):
@@ -123,10 +141,22 @@
         self.query_key_value = mpu.ColumnParallelLinear(
             args.hidden_size,
             3 * args.hidden_size,
             stride=3,
             gather_output=False,
             init_method=init_method)

+        coeff = None
+        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
+        if self.apply_query_key_layer_scaling:
+            coeff = self.layer_number
+            self.norm_factor *= coeff
+
+        self.scale_mask_softmax = FusedScaleMaskSoftmax(
+            self.fp16,
+            args.scaled_upper_triang_masked_softmax_fusion,
+            self.attention_mask_func,
+            self.attention_softmax_in_fp32,
+            coeff)
+
         # Dropout. Note that for a single iteration, this layer will generate
         # different outputs on different number of parallel partitions but
         # on average it should not be partition dependent.
@@ -137,110 +167,85 @@
             args.hidden_size,
             args.hidden_size,
             input_is_parallel=True,
-            init_method=output_layer_init_method)
-        self.output_dropout = torch.nn.Dropout(args.hidden_dropout)
+            init_method=output_layer_init_method,
+            skip_bias_add=True)

-    def _transpose_for_scores(self, tensor):
-        """Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with
-        size [b, np, s, hn].
-        """
-        new_tensor_shape = tensor.size()[:-1] + \
-                           (self.num_attention_heads_per_partition,
-                            self.hidden_size_per_attention_head)
-        tensor = tensor.view(*new_tensor_shape)
-        return tensor.permute(0, 2, 1, 3)
-
-    def _get_query_key_value(self, hidden_states):
-        """Get query, key, and value and transpose to
-        get size [b, np, s, hn].
-        """
-        # Attention heads. [b, s, hp]
-        mixed_x_layer = self.query_key_value(hidden_states)
-        (mixed_query_layer,
-         mixed_key_layer,
-         mixed_value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)
-        # Reshape and transpose [b, np, s, hn]
-        query_layer = self._transpose_for_scores(mixed_query_layer)
-        key_layer = self._transpose_for_scores(mixed_key_layer)
-        value_layer = self._transpose_for_scores(mixed_value_layer)
-        return query_layer, key_layer, value_layer
-
-    def _get_unmasked_attention_scores(self, query_layer, key_layer):
-        """Unmasked attention scores with size [b, np, s, s]."""
-        coeff = 1
-        if self.apply_query_key_layer_scaling:
-            coeff = self.layer_number
-        norm_factor = math.sqrt(coeff * math.sqrt(
-            self.hidden_size_per_attention_head))
-        # Raw attention scores. [b, np, s, s]
-        return torch.matmul(query_layer / norm_factor,
-                            key_layer.transpose(-1, -2) / norm_factor)
-
-    def _get_attention_probs(self, attention_scores):
-        """Attention probabilies with dropout. The output has
-        the size [b, np, s, s].
-        """
-        # Attention probabilities. [b, np, s, s]
-        if self.apply_query_key_layer_scaling:
-            attention_scores = attention_scores * self.layer_number
-        attention_probs = torch.nn.Softmax(dim=-1)(attention_scores)
-        # This is actually dropping out entire tokens to attend to, which might
-        # seem a bit unusual, but is taken from the original Transformer paper.
-        with mpu.get_cuda_rng_tracker().fork():
-            attention_probs = self.attention_dropout(attention_probs)
-        return attention_probs
-
-    def _get_attended_context(self, attention_probs, value_layer):
-        """Final attended tesnor and transposed back to [b, s, hp]."""
-        # Context layer.
-        # [b, np, s, hn]
-        context_layer = torch.matmul(attention_probs, value_layer)
-        # [b, s, np, hn]
-        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
-        new_context_layer_shape = context_layer.size()[:-2] + \
-            (self.hidden_size_per_partition,)
-        # [b, s, hp]
-        context_layer = context_layer.view(*new_context_layer_shape)
-        return context_layer
-
-    def _get_output(self, context_layer):
-        """Output layer with dropout."""
-        # Output. [b, s, h]
-        output = self.dense(context_layer)
-        output = self.output_dropout(output)
-        return output

     def forward(self, hidden_states, attention_mask, layer_past=None,
                 get_key_value=False):
-        # hidden_states: [b, s, h]
-
-        # Attention heads. [b, np, s, hn]
-        query_layer, key_layer, value_layer = \
-            self._get_query_key_value(hidden_states)
+        # hidden_states: [s, b, h]
+
+        # =====================
+        # Query, Key, and Value
+        # =====================
+
+        # Attention heads [s, b, hp] --> [s, b, 3 * hp]
+        mixed_x_layer, _ = self.query_key_value(hidden_states)
+
+        # [s, b, 3 * hp] --> [s, b, np, 3 * hn]
+        new_tensor_shape = mixed_x_layer.size()[:-1] + \
+            (self.num_attention_heads_per_partition,
+             3 * self.hidden_size_per_attention_head)
+        mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
+
+        # [s, b, np, 3 * hn] --> 3 [s, b, np, hn]
+        (query_layer,
+         key_layer,
+         value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)
+
+        # ==================================
+        # Adjust key and value for inference
+        # ==================================

         if layer_past is not None:
             past_key, past_value = layer_past
             key_layer = torch.cat((past_key.type_as(key_layer),
-                                   key_layer), dim=-2)
+                                   key_layer), dim=0)
             value_layer = torch.cat((past_value.type_as(value_layer),
-                                     value_layer), dim=-2)
+                                     value_layer), dim=0)
         if get_key_value:
             present = (key_layer, value_layer)

-        # Raw attention scores. [b, np, s, s]
-        attention_scores = self._get_unmasked_attention_scores(
-            query_layer, key_layer)
-
-        # fp32 conversion.
-        if self.fp16 and self.attention_softmax_in_fp32:
-            attention_scores = attention_scores.float()
+        # ===================================
+        # Raw attention scores. [b, np, s, s]
+        # ===================================
+
+        # [b, np, s, s]
+        output_size = (query_layer.size(1),
+                       query_layer.size(2),
+                       query_layer.size(0),
+                       key_layer.size(0))
+
+        # [s, b, np, hn] -> [s, b * np, hn]
+        query_layer = query_layer.view(output_size[2],
+                                       output_size[0] * output_size[1], -1)
+        key_layer = key_layer.view(output_size[3],
+                                   output_size[0] * output_size[1], -1)
+
+        # preallocting result tensor: [b * np, s, s]
+        matmul_result = torch.empty(
+            output_size[0] * output_size[1],
+            output_size[2],
+            output_size[3],
+            dtype=query_layer.dtype,
+            device=torch.cuda.current_device())
+
+        # Raw attention scores. [b * np, s, s]
+        matmul_result = torch.baddbmm(matmul_result,
+            query_layer.transpose(0, 1),                 # [b * np, s, hn]
+            key_layer.transpose(0, 1).transpose(1, 2),   # [b * np, hn, s]
+            beta=0.0, alpha=(1.0 / self.norm_factor))
+
+        # change view to [b, np, s, s]
+        attention_scores = matmul_result.view(*output_size)
+
+        # ==================================================
+        # Update attention mask for inference. [b, np, s, s]
+        # ==================================================

-        # Apply attention mask. [b, np, s, s]
         if get_key_value:
             with torch.no_grad():
                 if layer_past is not None:
@@ -253,26 +258,93 @@
                         ..., :attention_scores.size(3),
                         :attention_scores.size(3)]

-        attention_scores = self.attention_mask_func(attention_scores,
-                                                    attention_mask)
-
-        # Attention probabilities. [b, np, s, s]
-        attention_probs = self._get_attention_probs(attention_scores)
-
-        # fp16 conversion
-        if self.fp16 and self.attention_softmax_in_fp32:
-            attention_probs = attention_probs.half()
-
-        # Context layer. [b, s, hp]
-        context_layer = self._get_attended_context(attention_probs, value_layer)
-
-        # Output. [b, s, h]
-        output = self._get_output(context_layer)
+        # ===========================
+        # Attention probs and dropout
+        # ===========================
+
+        # attention scores and attention mask [b, np, s, s]
+        attention_probs = self.scale_mask_softmax(attention_scores,
+                                                  attention_mask)
+
+        # This is actually dropping out entire tokens to attend to, which might
+        # seem a bit unusual, but is taken from the original Transformer paper.
+        with mpu.get_cuda_rng_tracker().fork():
+            attention_probs = self.attention_dropout(attention_probs)
+
+        # =========================
+        # Context layer. [s, b, hp]
+        # =========================
+
+        # value_layer -> context layer.
+        # [s, b, np, hn] --> [b, np, s, hn]
+
+        # context layer shape: [b, np, s, hn]
+        output_size = (value_layer.size(1),
+                       value_layer.size(2),
+                       value_layer.size(0),
+                       value_layer.size(3))
+
+        # change view [s, b * np, hn]
+        value_layer = value_layer.view(output_size[2],
+                                       output_size[0] * output_size[1], -1)
+
+        # change view [b * np, s, s]
+        attention_probs = attention_probs.view(output_size[0] * output_size[1],
+                                               output_size[2], -1)
+
+        # matmul: [b * np, s, hn]
+        context_layer = torch.bmm(attention_probs,
+                                  value_layer.transpose(0, 1))
+
+        # change view [b, np, s, hn]
+        context_layer = context_layer.view(*output_size)
+
+        # [b, np, s, hn] --> [s, b, np, hn]
+        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
+
+        # [s, b, np, hn] --> [s, b, hp]
+        new_context_layer_shape = context_layer.size()[:-2] + \
+            (self.hidden_size_per_partition,)
+        context_layer = context_layer.view(*new_context_layer_shape)
+
+        # =================
+        # Output. [s, b, h]
+        # =================
+
+        output, bias = self.dense(context_layer)

         if get_key_value:
             output = [output, present]

-        return output
+        return output, bias
+
+
+def bias_dropout_add(x, bias, residual, prob, training):
+    # type: (Tensor, Tensor, Tensor, float, bool) -> Tensor
+    out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
+    out = residual + out
+    return out
+
+
+def get_bias_dropout_add(training):
+    def _bias_dropout_add(x, bias, residual, prob):
+        return bias_dropout_add(x, bias, residual, prob, training)
+    return _bias_dropout_add
+
+
+@torch.jit.script
+def bias_dropout_add_fused_train(x, bias, residual, prob):
+    # type: (Tensor, Tensor, Tensor, float) -> Tensor
+    return bias_dropout_add(x, bias, residual, prob, True)
+
+
+@torch.jit.script
+def bias_dropout_add_fused_inference(x, bias, residual, prob):
+    # type: (Tensor, Tensor, Tensor, float) -> Tensor
+    return bias_dropout_add(x, bias, residual, prob, False)


 class ParallelTransformerLayer(MegatronModule):
@@ -282,8 +354,8 @@
     output of the same size.
     """

-    def __init__(self, attention_mask_func, mlp_activation_func,
-                 init_method, output_layer_init_method, layer_number):
+    def __init__(self, attention_mask_func, init_method,
+                 output_layer_init_method, layer_number):
         args = get_args()

         super(ParallelTransformerLayer, self).__init__()
@@ -301,6 +373,8 @@
         self.attention = ParallelSelfAttention(attention_mask_func,
                                                init_method,
                                                output_layer_init_method,
                                                layer_number)
+        self.hidden_dropout = args.hidden_dropout
+        self.bias_dropout_fusion = args.bias_dropout_fusion

         # Layernorm on the input data.
         self.post_attention_layernorm = LayerNorm(
@@ -308,7 +382,7 @@
             eps=args.layernorm_epsilon)

         # MLP
-        self.mlp = ParallelMLP(mlp_activation_func, init_method,
+        self.mlp = ParallelMLP(init_method,
                                output_layer_init_method)

     def forward(self, hidden_states, attention_mask, layer_past=None,
@@ -318,28 +392,60 @@
         # Layer norm at the begining of the transformer layer.
         layernorm_output = self.input_layernorm(hidden_states)

         # Self attention.
-        attention_output = self.attention(layernorm_output,
-                                          attention_mask,
-                                          layer_past=layer_past,
-                                          get_key_value=get_key_value)
+        attention_output, attention_bias = \
+            self.attention(layernorm_output,
+                           attention_mask,
+                           layer_past=layer_past,
+                           get_key_value=get_key_value)

         if get_key_value:
             attention_output, presents = attention_output

         # Residual connection.
         if self.apply_residual_connection_post_layernorm:
-            layernorm_input = layernorm_output + attention_output
+            residual = layernorm_output
         else:
-            layernorm_input = hidden_states + attention_output
+            residual = hidden_states
+
+        # jit scripting for a nn.module (with dropout) is not
+        # trigerring the fusion kernel. For now, we use two
+        # different nn.functional routines to account for varying
+        # dropout semantics during training and inference phases.
+        if self.bias_dropout_fusion:
+            if self.training:
+                bias_dropout_add_func = bias_dropout_add_fused_train
+            else:
+                bias_dropout_add_func = bias_dropout_add_fused_inference
+        else:
+            bias_dropout_add_func = get_bias_dropout_add(self.training)
+
+        #re-enable torch grad to enable fused optimization.
+        with torch.enable_grad():
+            layernorm_input = bias_dropout_add_func(
+                attention_output,
+                attention_bias.expand_as(residual),
+                residual,
+                self.hidden_dropout)

         # Layer norm post the self attention.
         layernorm_output = self.post_attention_layernorm(layernorm_input)

         # MLP.
-        mlp_output = self.mlp(layernorm_output)
+        mlp_output, mlp_bias = self.mlp(layernorm_output)

         # Second residual connection.
         if self.apply_residual_connection_post_layernorm:
-            output = layernorm_output + mlp_output
+            residual = layernorm_output
         else:
-            output = layernorm_input + mlp_output
+            residual = layernorm_input
+
+        #re-enable torch grad to enable fused optimization.
+        with torch.enable_grad():
+            output = bias_dropout_add_func(
+                mlp_output,
+                mlp_bias.expand_as(residual),
+                residual,
+                self.hidden_dropout)

         if get_key_value:
             output = [output, presents]
@@ -350,7 +456,7 @@
 class ParallelTransformer(MegatronModule):
     """Transformer class."""

-    def __init__(self, attention_mask_func, mlp_activation_func,
+    def __init__(self, attention_mask_func,
                  init_method, output_layer_init_method):
         super(ParallelTransformer, self).__init__()
         args = get_args()
@@ -371,8 +477,8 @@
         # Transformer layers.
         def build_layer(layer_number):
             return ParallelTransformerLayer(
-                attention_mask_func, mlp_activation_func, init_method,
-                output_layer_init_method, layer_number)
+                attention_mask_func, init_method,
+                output_layer_init_method, layer_number)

         self.layers = torch.nn.ModuleList(
             [build_layer(i + 1) for i in range(self.num_unique_layers)])
@@ -435,6 +541,9 @@
             'get_key_value does not work with ' \
             'activation checkpointing'

+        # data format change to avoid explicit tranposes : [b s h] --> [s b h]
+        hidden_states = hidden_states.transpose(0, 1).contiguous()
+
         if self.checkpoint_activations:
             hidden_states = self._checkpointed_forward(hidden_states,
                                                        attention_mask)
@@ -453,6 +562,9 @@
                 if get_key_value:
                     hidden_states, present = hidden_states
                     presents.append(present)

+        # reverting data format change [s b h] --> [b s h]
+        hidden_states = hidden_states.transpose(0, 1).contiguous()
+
         # Final layer norm.
         output = self.final_layernorm(hidden_states)
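The bias_dropout_add helpers introduced above exist purely so the bias add, dropout and residual add can be fused by the TorchScript fuser; mathematically they match the eager composition. A standalone sketch restating the pattern to show the equivalence in inference mode, where dropout is a no-op (illustrative, not Megatron code):

    import torch

    def bias_dropout_add(x, bias, residual, prob, training):
        out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
        return residual + out

    @torch.jit.script
    def bias_dropout_add_fused_inference(x, bias, residual, prob):
        # type: (Tensor, Tensor, Tensor, float) -> Tensor
        out = torch.nn.functional.dropout(x + bias, p=prob, training=False)
        return residual + out

    x, residual = torch.randn(3, 8, 32), torch.randn(3, 8, 32)
    bias = torch.randn(32).expand_as(residual)

    eager = bias_dropout_add(x, bias, residual, 0.1, False)
    fused = bias_dropout_add_fused_inference(x, bias, residual, 0.1)
    print(torch.allclose(eager, fused))  # True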
megatron/mpu/layers.py  (+25 / -9)

@@ -54,7 +54,7 @@ def _initialize_affine_weight_gpu(weight, init_method,
     weight.model_parallel = True
     weight.partition_dim = partition_dim
     weight.partition_stride = stride

     with get_cuda_rng_tracker().fork():
         init_method(weight)
@@ -186,11 +186,15 @@ class ColumnParallelLinear(torch.nn.Module):
     keep_master_weight_for_test: This was added for testing and should be
                                  set to False. It returns the master weights
                                  used for initialization.
+    skip_bias_add: This was added to enable performance optimations where bias
+                   can be fused with other elementwise operations. we skip
+                   adding bias but instead return it.
     """

     def __init__(self, input_size, output_size, bias=True, gather_output=True,
                  init_method=init.xavier_normal_, stride=1,
-                 keep_master_weight_for_test=False):
+                 keep_master_weight_for_test=False, skip_bias_add=False):
         super(ColumnParallelLinear, self).__init__()

         # Keep input parameters
@@ -200,6 +204,7 @@ class ColumnParallelLinear(torch.nn.Module):
         # Divide the weight matrix along the last dimension.
         world_size = get_model_parallel_world_size()
         self.output_size_per_partition = divide(output_size, world_size)
+        self.skip_bias_add = skip_bias_add

         # Parameters.
         # Note: torch.nn.functional.linear performs XA^T + b and as a result
@@ -245,13 +250,16 @@ class ColumnParallelLinear(torch.nn.Module):
         # Set up backprop all-reduce.
         input_parallel = copy_to_model_parallel_region(input_)
         # Matrix multiply.
-        output_parallel = F.linear(input_parallel, self.weight, self.bias)
+        bias = self.bias if not self.skip_bias_add else None
+        output_parallel = F.linear(input_parallel, self.weight, bias)
         if self.gather_output:
             # All-gather across the partitions.
             output = gather_from_model_parallel_region(output_parallel)
         else:
             output = output_parallel
-        return output
+        output_bias = self.bias if self.skip_bias_add else None
+        return output, output_bias


 class RowParallelLinear(torch.nn.Module):
@@ -279,12 +287,16 @@ class RowParallelLinear(torch.nn.Module):
     keep_master_weight_for_test: This was added for testing and should be
                                  set to False. It returns the master weights
                                  used for initialization.
+    skip_bias_add: This was added to enable performance optimations where bias
+                   can be fused with other elementwise operations. we skip
+                   adding bias but instead return it.
     """

     def __init__(self, input_size, output_size, bias=True,
                  input_is_parallel=False,
                  init_method=init.xavier_normal_, stride=1,
-                 keep_master_weight_for_test=False):
+                 keep_master_weight_for_test=False, skip_bias_add=False):
         super(RowParallelLinear, self).__init__()

         # Keep input parameters
@@ -294,6 +306,7 @@ class RowParallelLinear(torch.nn.Module):
         # Divide the weight matrix along the last dimension.
         world_size = get_model_parallel_world_size()
         self.input_size_per_partition = divide(input_size, world_size)
+        self.skip_bias_add = skip_bias_add

         # Parameters.
         # Note: torch.nn.functional.linear performs XA^T + b and as a result
@@ -340,8 +353,11 @@ class RowParallelLinear(torch.nn.Module):
         output_parallel = F.linear(input_parallel, self.weight)
         # All-reduce across all the partitions.
         output_ = reduce_from_model_parallel_region(output_parallel)
-        if self.bias is not None:
-            output = output_ + self.bias
+        if not self.skip_bias_add:
+            output = output_ + self.bias if self.bias is not None else output_
+            output_bias = None
         else:
             output = output_
-        return output
+            output_bias = self.bias
+        return output, output_bias
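The skip_bias_add contract is the same for both parallel linear layers: return the un-biased matmul result together with the bias tensor, and let the caller fuse the addition into a later elementwise kernel (bias_gelu_impl, bias_dropout_add). A sketch of the contract using a plain linear layer as a stand-in, since the real ColumnParallelLinear/RowParallelLinear need the model-parallel groups initialized:

    import torch
    import torch.nn.functional as F

    class LinearSkipBiasAdd(torch.nn.Linear):
        def forward(self, input_):
            # Defer the bias: return (x @ W^T, bias) and let the caller add it.
            return F.linear(input_, self.weight, None), self.bias

    layer = LinearSkipBiasAdd(16, 32)
    x = torch.randn(4, 16)

    out, bias = layer(x)
    deferred = out + bias                       # fused later in the real model
    reference = F.linear(x, layer.weight, layer.bias)
    print(torch.allclose(deferred, reference, atol=1e-6))  # True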
megatron/training.py  (+12 / -3)

@@ -236,29 +236,35 @@ def backward_step(optimizer, model, loss):
     timers = get_timers()

     # Backward pass.
+    timers('backward-backward').start()
     optimizer.zero_grad(set_grads_to_None=True)
     if args.fp16:
         optimizer.backward(loss, update_master_grads=False)
     else:
         loss.backward()
+    timers('backward-backward').stop()

     # All-reduce if needed.
     if args.DDP_impl == 'local':
-        timers('allreduce').start()
+        timers('backward-allreduce').start()
         model.allreduce_params(reduce_after=False,
                                fp32_allreduce=args.fp32_allreduce)
-        timers('allreduce').stop()
+        timers('backward-allreduce').stop()

     # Update master gradients.
+    timers('backward-master-grad').start()
     if args.fp16:
         optimizer.update_master_grads()
+    timers('backward-master-grad').stop()

     # Clipping gradients helps prevent the exploding gradient.
+    timers('backward-clip-grad').start()
     if args.clip_grad > 0:
         if not args.fp16:
             mpu.clip_grad_norm(model.parameters(), args.clip_grad)
         else:
             optimizer.clip_master_grads(args.clip_grad)
+    timers('backward-clip-grad').stop()


 def train_step(forward_step_func, data_iterator,
@@ -311,7 +317,10 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
             timers_to_log.append(name)

     add_to_logging('forward')
-    add_to_logging('backward')
-    add_to_logging('allreduce')
+    add_to_logging('backward-backward')
+    add_to_logging('backward-allreduce')
+    add_to_logging('backward-master-grad')
+    add_to_logging('backward-clip-grad')
     add_to_logging('optimizer')
     add_to_logging('batch generator')
megatron/utils.py  (+3 / -3)

@@ -46,9 +46,9 @@ def report_memory(name):
         torch.cuda.memory_allocated() / mega_bytes)
     string += ' | max allocated: {}'.format(
         torch.cuda.max_memory_allocated() / mega_bytes)
-    string += ' | cached: {}'.format(torch.cuda.memory_cached() / mega_bytes)
-    string += ' | max cached: {}'.format(
-        torch.cuda.max_memory_cached() / mega_bytes)
+    string += ' | reserved: {}'.format(torch.cuda.memory_reserved() / mega_bytes)
+    string += ' | max reserved: {}'.format(
+        torch.cuda.max_memory_reserved() / mega_bytes)
     print_rank_0(string)
pretrain_gpt2.py  (+0 / -3)

@@ -28,7 +28,6 @@ from megatron.training import pretrain
 from megatron.utils import get_ltor_masks_and_position_ids
 from megatron.utils import reduce_losses

 def model_provider():
     """Build the model."""
@@ -80,10 +79,8 @@ def forward_step(data_iterator, model):
     tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
         data_iterator)
     timers('batch generator').stop()
     # Forward model.
     losses = model(tokens, position_ids, attention_mask, labels=labels)
     loss_mask = loss_mask.view(-1)
     loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()