OpenDAS / Megatron-LM

Commit 051f58f1 ("v1"), authored Mar 27, 2024 by liangjing. Parent: 0024a5c6. Pipeline #829 passed.

Showing 20 changed files with 950 additions and 184 deletions (+950 / -184).
Changed files:

megatron/optimizer/distrib_optimizer.py    +341   -35
megatron/optimizer/optimizer.py              +5    -6
megatron/optimizer_param_scheduler.py       +11    -3
megatron/text_generation/forward_step.py     +3   -31
megatron/text_generation/tokenization.py     +4    -1
megatron/text_generation_server.py           +2    -2
megatron/tokenizer/tokenizer.py             +40    -5
megatron/training.py                       +136   -74
pretrain_bert.py                             +3    -2
pretrain_gpt.py                              +9    -5
pretrain_gpt_core.py                       +130    -0
pretrain_retro.py                           +10    -7
pretrain_t5.py                               +4    -3
pretrain_vision_classify.py                  +4    -2
pretrain_vision_dino.py                      +3    -1
pretrain_vision_inpaint.py                   +4    -1
pyproject.toml                              +18    -0
requirements.txt                            +13    -0
setup.py                                   +107    -6
single.sh                                  +103    -0
megatron/optimizer/distrib_optimizer.py

# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

"""Megatron distributed optimizer."""

from apex.optimizers import FusedAdam as Adam
import math
import torch
...
@@ -28,6 +29,8 @@ class Range:
        return Range(start, start + self.size)
    def __str__(self):
        return "%d,%d [%d]" % (self.start, self.end, self.size)
    def __len__(self):
        return self.end - self.start


class DistributedOptimizer(MixedPrecisionOptimizer):
...
@@ -206,27 +209,40 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
        num_groups = len(param_groups)

        # Param group map.
        param_group_map = {}
        # World param group map.
        # - Store a mapping of <model_parameter:group_index> for all parameters
        #   across all DP ranks. This is necessary because it is our first
        #   cross reference between the DDP mappings and the optimizer group
        #   parameters. This mapping only for use in the next step of building
        #   the local mapping over this DP rank's parameters.
        world_param_group_map = {}
        for group_index, group in enumerate(param_groups):
            for param in group["params"]:
                assert param.requires_grad
                param_group_map[param] = group_index

        # Optimizer group ranges.
                world_param_group_map[param] = group_index

        # Optimizer group ranges & param-group mapping.
        # - Build a mapping from groups to their contained parameters, and also
        #   from parameters to their containing group index and order within
        #   the group. The group index and order are particularly important for
        #   saving and loading checkpoints.
        local_param_group_map = {}
        group_ranges = [ {"params": []} for _ in param_groups ]
        for model_gbuf_range_map in model_gbuf_ranges:
            for dtype, gbuf_range_map in model_gbuf_range_map.items():
                for param in gbuf_range_map["param_map"]:
                    group_index = param_group_map[param]
                    group_index = world_param_group_map[param]
                    group_range = group_ranges[group_index]
                    group_range["params"].append(param)
                    local_param_group_map[param] = \
                        (group_index, len(group_range["params"]) - 1)

        # Squeeze zero-size group ranges.
        for group_index, group_range in enumerate(group_ranges):
            group_range["orig_group"] = param_groups[group_index]
        group_ranges = [ g for g in group_ranges if len(g["params"]) > 0 ]
            group_range["orig_group_idx"] = param_groups[group_index]

        return group_ranges
        return local_param_group_map, group_ranges

    @classmethod
...
@@ -318,7 +334,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                                     'torch.cuda.FloatTensor, '
                                     'torch.cuda.HalfTensor, or '
                                     'torch.cuda.BFloat16Tensor. '
                                     'Received {}'.format(param.type()))
                                     'Received {}'.format(model_param.type()))

            # Update optimizer's params.
            group_range["orig_group"]["params"] = [
...
@@ -356,6 +372,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
        # Verify that contiguous buffers are being used.
        # - Note: this should already be checked in arguments.py.
        assert use_contiguous_buffers_in_local_ddp
        assert isinstance(optimizer, Adam), \
            "Only Adam currently supported, due to checkpointing requirements."

        # Model grad buffer ranges.
        self.model_gbuf_ranges = []
...
@@ -365,9 +383,9 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
            self.build_model_param_gbuf_map(self.model_gbuf_ranges)

        # Optimizer ranges.
        self.opt_group_ranges = self.build_optimizer_group_ranges(
            self.optimizer.param_groups,
            self.model_gbuf_ranges)
        self.model_param_group_index_map, self.opt_group_ranges = \
            self.build_optimizer_group_ranges(self.optimizer.param_groups,
                                              self.model_gbuf_ranges)

        # Allocate main param shards.
        (
...
@@ -388,9 +406,18 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
        for model_index, model in enumerate(self.models):
            current_param_buffers = {}
            for dtype, grad_buffer in model._grad_buffers.items():
                param_buffer = torch.tensor(grad_buffer.data.storage()._untyped(),
                                            dtype=params_dtype,
                                            device=grad_buffer.data.device)

                # Handle older/newer method for getting untyped storage.
                try:
                    storage = grad_buffer.data.storage()._untyped()
                except:
                    storage = grad_buffer.data.storage().untyped()

                # Typed param buffer.
                param_buffer = torch.tensor(storage,
                                            dtype=params_dtype,
                                            device=grad_buffer.data.device)

                param_buffer = param_buffer[:grad_buffer.numel_padded]
                current_param_buffers[dtype] = param_buffer
            self.param_buffers.append(current_param_buffers)
...
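For context: the try/except above papers over the rename of the private storage accessor across PyTorch releases. A minimal sketch of the same compatibility check, assuming nothing beyond plain PyTorch (the helper name is hypothetical, not part of the commit):

import torch

def get_untyped_storage(tensor):
    # Hypothetical helper (not from the commit): fetch a tensor's untyped
    # storage across PyTorch versions. Older releases only expose the private
    # ``_untyped()`` accessor; newer ones renamed it to ``untyped()``.
    storage = tensor.storage()
    if hasattr(storage, "untyped"):
        return storage.untyped()
    return storage._untyped()

grad_buffer_data = torch.zeros(8, dtype=torch.float16)
storage = get_untyped_storage(grad_buffer_data)
# The commit then builds a typed "param buffer" on top of these bytes with
# torch.tensor(storage, dtype=params_dtype, device=...), so the param and grad
# buffers can reuse one allocation.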
@@ -424,29 +451,108 @@ class DistributedOptimizer(MixedPrecisionOptimizer):

    def state_dict(self):
        """
        The state dict must contain the fp32-from-float16 shards.
        The state dict contains all non-DP-rank-dependent (i.e., non-parameter-
        related) optimizer variables. The returned state dict can be stored in
        the standard model/RNG checkpoint file. The parameter and dependent
        optimizer state (e.g., exp_avg, exp_avg_sq) are stored in a separate
        checkpoint file by calling 'save_parameter_state()'.
        """

        state_dict = {}
        state_dict['optimizer'] = self.optimizer.state_dict()

        # Optimizer state (do not store parameter state here).
        state_dict['optimizer'] = {
            k : v
            for k, v in self.optimizer.state_dict().items()
            if k != "state"
        }
        for param_group in state_dict["optimizer"]["param_groups"]:
            del param_group["params"]

        # Grad scaler state.
        if self.grad_scaler:
            state_dict['grad_scaler'] = self.grad_scaler.state_dict()
        state_dict['shard_fp32_from_float16_groups'] = \
            self.shard_fp32_from_float16_groups

        return state_dict

    def load_state_dict(self, state_dict):
        """
        Load the state dict.
        """Load the state dict.

        As detailed in state_dict(), the state dict contains all non-
        parameter-related variables. This method is notably longer than
        state_dict(), because the Torch optimizers state has yet to be
        allocated at this point, and so we must do a cross referencing between
        the optimizers state (and the ordering it expects for parameter state)
        and this DP rank's shards. The optimizer at this point does not contain
        any tensor dimension information, so we must get these dimensions from
        the DP shards mapped during DistributedOptimizer.__init__().

        The tensor parameter state is loaded via load_parameter_state(), and
        so this method also must populate the loaded state dict with dummy
        tensor data (i.e., via torch.empty() below). This will be overwritten
        during load_parameter_state().

        ** Note: Torch optimizer's state structure. **
        The Torch optimizer stores its state in two levels. The top level is a
        list of groups, where each group contains a list of integer indexes
        (corresponding to parameters) that index into a master parameter list
        that is shared by all groups. As such, three values are necessary for
        maintaining this ordering:
        - group_index : The group to which a parameter belongs.
        - group_order : The index of a parameter within its group.
        - state_order : The index of a parameter within the shared parameter
          list.
        """

        # Get the Torch optimizer's state dict.
        # - This 'inner' optimizer at this point is unallocated, and only
        #   contains an integer odering of parameters within each group, and
        #   the ordering of parameters within its flattened parameter state
        #   list.
        inner_state_dict = self.optimizer.state_dict()
        state_dict_param_groups = [{
            **group,
            "params" : list(inner_state_dict["param_groups"][idx]["params"]),
        } for idx, group in enumerate(state_dict["optimizer"]["param_groups"])]

        # Allocate 'dummy' data for optimizer state (i.e., torch.empty() below)
        # - Real data is overwritten during load_parameter_state().
        state_dict_state = []
        for gbuf_range_maps in self.model_gbuf_ranges:
            for gbuf_range_map in gbuf_range_maps.values():
                for model_param, param_range_map in \
                        gbuf_range_map["param_map"].items():

                    # Get parameter ordering information (see method docstring
                    # for details).
                    group_index, group_order = \
                        self.model_param_group_index_map[model_param]
                    state_order = inner_state_dict["param_groups"] \
                        [group_index]["params"][group_order]

                    # Allocate dummy tensors.
                    numel = len(param_range_map["gbuf_world"])
                    init_shard = lambda : torch.empty(
                        (numel,),
                        dtype=torch.float32,
                        device=torch.cuda.current_device())

                    state_dict_state.append((state_order, {
                        "exp_avg" : init_shard(),
                        "exp_avg_sq" : init_shard(),
                    }))

        # Sort by state order (see method docstring for details).
        state_dict_state.sort(key=lambda s : s[0])
        state_dict_state = {s[0]: s[1] for s in state_dict_state}

        # Optimizer.
        optimizer_key = 'optimizer'
        if optimizer_key not in state_dict:
            optimizer_key = 'optimizer_state_dict'
            print_rank_0('***WARNING*** loading optimizer from '
                         'an old checkpoint ...')
        self.optimizer.load_state_dict(state_dict[optimizer_key])
        self.optimizer.load_state_dict({
            "state" : state_dict_state,
            "param_groups" : state_dict_param_groups,
        })

        # Grad scaler.
        if 'grad_scaler' not in state_dict:
...
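The group_index / group_order / state_order bookkeeping above can be seen on a plain Torch optimizer. A small illustration with a toy two-group Adam optimizer (hypothetical example, not from the commit):

import torch

w1 = torch.nn.Parameter(torch.zeros(3))
w2 = torch.nn.Parameter(torch.zeros(2))
w3 = torch.nn.Parameter(torch.zeros(4))
opt = torch.optim.Adam([{"params": [w1, w2]}, {"params": [w3]}])

sd = opt.state_dict()
# Each group lists integer indexes into one shared, flat parameter ordering.
print(sd["param_groups"][0]["params"])   # [0, 1]
print(sd["param_groups"][1]["params"])   # [2]

# group_index=1, group_order=0  ->  state_order=2, i.e. w3's slot in sd["state"].
state_order = sd["param_groups"][1]["params"][0]
assert state_order == 2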
@@ -461,12 +567,180 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                    'checkpoint but it is None in the class. '
                    'Skipping loading grad scaler ...')

        # Copy data for the main params.
        for current_group, saved_group in zip(
                self.shard_fp32_from_float16_groups,
                state_dict["shard_fp32_from_float16_groups"]):
            for current_param, saved_param in zip(current_group, saved_group):
                current_param.data.copy_(saved_param.data)

    def save_parameter_state(self, filename):
        """Save parameter state (i.e., parameter & optimizer tensors).

        This method performs three steps:
        - For each DP rank, copy param & optimizer shards to contiguous CPU
          buffers. (e.g., one buffer each for main_param, exp_avg, and
          exp_avg_sq).
        - Gather contiguous buffers on DP rank 0 and concatenate to world
          buffers.
        - Save world buffers to disk (i.e., distrib_opt.pt).
        """

        # Data parallelism variables.
        data_parallel_world_size = mpu.get_data_parallel_world_size()
        data_parallel_rank = mpu.get_data_parallel_rank()
        data_parallel_group_gloo = mpu.get_data_parallel_group_gloo()
        data_parallel_global_ranks = list(mpu._DATA_PARALLEL_GLOBAL_RANKS)

        # Collect param states.
        state = {}
        for model_idx, gbuf_range_maps in enumerate(self.model_gbuf_ranges):

            # Iterate grad buffers (by data type).
            dtype_state = {}
            assert len(gbuf_range_maps) == 1, "single dtype supported, for now."
            for dtype, gbuf_range_map in gbuf_range_maps.items():

                # Compute local DP contiguous shard's size.
                model = self.models[model_idx]
                gbuf_world_numel = model._grad_buffers[dtype].numel_padded
                gbuf_local_numel = int(gbuf_world_numel / data_parallel_world_size)
                local_shards = {key: torch.empty((gbuf_local_numel,),
                                                 dtype=torch.float32,
                                                 device="cpu")
                                for key in ("param", "exp_avg", "exp_avg_sq")}

                # Build contiguous DP rank shards (for param + optim states).
                for model_param, param_range_map in \
                        gbuf_range_map["param_map"].items():

                    # Main param & optimizer states.
                    group_index, group_order = \
                        self.model_param_group_index_map[model_param]
                    main_param = self.optimizer.param_groups \
                        [group_index]["params"][group_order]
                    optim_state = self.optimizer.state[main_param]

                    tensors = {
                        "param" : main_param,
                        **optim_state,
                    }

                    # Copy states into contiguous shard.
                    gbuf_local_start = param_range_map["gbuf_local"].start
                    gbuf_local_end = param_range_map["gbuf_local"].end
                    for key in local_shards:
                        local_shards[key][gbuf_local_start:gbuf_local_end] \
                            .data.copy_(tensors[key].detach().cpu())

                # Gather contiguous shards on DP rank 0.
                world_tensors = {}
                for key, send_tensor in local_shards.items():

                    # Gather tensor list.
                    if data_parallel_rank == 0:
                        recv_tensors = [torch.empty((gbuf_local_numel,),
                                                    dtype=torch.float32,
                                                    device="cpu")
                                        for _ in range(data_parallel_world_size)]
                    else:
                        recv_tensors = None

                    # Gather.
                    torch.distributed.gather(
                        send_tensor,
                        recv_tensors,
                        data_parallel_global_ranks[0],
                        data_parallel_group_gloo,
                    )

                    # Concatenate.
                    if data_parallel_rank == 0:
                        world_tensors[key] = torch.cat(recv_tensors)

                # Collect world state.
                dtype_state[dtype] = world_tensors
            state[model_idx] = dtype_state

        # Save param state.
        if data_parallel_rank == 0:
            torch.save(state, filename)
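The gather step above follows a simple pattern: each DP rank holds one equally sized CPU shard, and rank 0 gathers and concatenates them into a "world" buffer. A minimal standalone sketch, assuming a Gloo process group has already been initialized (e.g., via torch.distributed.init_process_group("gloo")); the helper name is illustrative, not part of the commit:

import torch
import torch.distributed as dist

def gather_world_buffer(local_shard, group=None):
    # Gather equally sized CPU shards onto rank 0 and concatenate them into
    # one contiguous world buffer, mirroring save_parameter_state() above.
    world_size = dist.get_world_size(group)
    rank = dist.get_rank(group)
    if rank == 0:
        recv = [torch.empty_like(local_shard) for _ in range(world_size)]
    else:
        recv = None
    dist.gather(local_shard, recv, dst=0, group=group)
    return torch.cat(recv) if rank == 0 else None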
    def load_parameter_state(self, filename):
        """Load parameter state (i.e., parameter & optimizer tensors).

        This method performs the reverse of save_parameter_state():
        - Load world buffers from disk (i.e., distrib_opt.pt).
        - Scatter contiguous buffers from DP rank 0 to each DP rank (each DP
          rank receives its relevant subset of the world buffers).
        - For each DP rank, copy param & optimizer shards from contiguous CPU
          buffers. (e.g., one buffer each for main_param, exp_avg, and
          exp_avg_sq).
        """

        # Data parallelism variables.
        data_parallel_world_size = mpu.get_data_parallel_world_size()
        data_parallel_rank = mpu.get_data_parallel_rank()
        data_parallel_group_gloo = mpu.get_data_parallel_group_gloo()
        data_parallel_global_ranks = list(mpu._DATA_PARALLEL_GLOBAL_RANKS)

        # Load on DP rank 0.
        if data_parallel_rank == 0:
            loaded_state = torch.load(filename)

        # Scatter tensors to all DP ranks.
        for model_idx, gbuf_range_maps in enumerate(self.model_gbuf_ranges):
            for dtype, gbuf_range_map in gbuf_range_maps.items():

                # Compute local DP contiguous shard's size.
                model = self.models[model_idx]
                gbuf_world_numel = model._grad_buffers[dtype].numel_padded
                gbuf_local_numel = int(gbuf_world_numel / data_parallel_world_size)

                # Contiguous local shards (received from DP rank 0).
                local_shards = {key: torch.empty((gbuf_local_numel,),
                                                 dtype=torch.float32,
                                                 device="cpu")
                                for key in ("param", "exp_avg", "exp_avg_sq")}

                # Scatter local shards from DP rank 0.
                for key, recv_tensor in local_shards.items():

                    # Scatter tensor list.
                    if data_parallel_rank == 0:
                        world_tensor = loaded_state[model_idx][dtype][key]
                        gbuf_start_idxs = \
                            list(range(0, gbuf_world_numel, gbuf_local_numel))
                        send_tensors = [world_tensor[i:(i + gbuf_local_numel)]
                                        for i in gbuf_start_idxs]
                    else:
                        send_tensors = None

                    # Scatter.
                    torch.distributed.scatter(
                        recv_tensor,
                        send_tensors,
                        data_parallel_global_ranks[0],
                        data_parallel_group_gloo,
                    )

                # Copy local contiguous shards to param/optim shards.
                for model_param, param_range_map in \
                        gbuf_range_map["param_map"].items():

                    # Main param & optimizer states.
                    group_index, group_order = \
                        self.model_param_group_index_map[model_param]
                    main_param = self.optimizer.param_groups \
                        [group_index]["params"][group_order]
                    optim_state = self.optimizer.state[main_param]

                    tensors = {
                        "param" : main_param,
                        **optim_state,
                    }

                    # Copy states into contiguous shard.
                    gbuf_local_start = param_range_map["gbuf_local"].start
                    gbuf_local_end = param_range_map["gbuf_local"].end
                    for key in local_shards:
                        tensors[key].data.copy_(
                            local_shards[key][gbuf_local_start:gbuf_local_end])

    def zero_grad(self, set_to_none=True):
...
@@ -583,6 +857,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
        timers('grads-reduce-scatter').stop()

    def gather_model_params(self, args, timers):
        """
        All-gather updated model params.
...
@@ -617,9 +892,9 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
        # Copy from param buffer to each param.
        for model_id, model in enumerate(self.models):
            for dtype, param_map in model._grad_buffer_param_index_map.items():
                for param, buf_range in param_map.items():
                for param, (buf_start, buf_end) in param_map.items():
                    param_buf = self.param_buffers[model_id][dtype]
                    param_buf_shard = param_buf[buf_range[0]:buf_range[1]]
                    param_buf_shard = param_buf[buf_start:buf_end]
                    param.view(-1).detach().copy_(param_buf_shard)

        timers('params-all-gather').stop()
...
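The copy above writes a slice of the flat, all-gathered param buffer back into each model parameter through a 1-D view. A toy sketch with made-up shapes and offsets (not from the commit):

import torch

# Flat param buffer holding two parameters back to back; the index map records
# (start, end) offsets per parameter, as in _grad_buffer_param_index_map above.
param_buf = torch.arange(10, dtype=torch.float32)
weight = torch.zeros(2, 3)                 # occupies offsets [0, 6)
bias = torch.zeros(4)                      # occupies offsets [6, 10)
index_map = {weight: (0, 6), bias: (6, 10)}

for param, (buf_start, buf_end) in index_map.items():
    # Copy the flat shard into the parameter without changing its shape.
    param.view(-1).detach().copy_(param_buf[buf_start:buf_end])

print(weight)   # tensor([[0., 1., 2.], [3., 4., 5.]])
print(bias)     # tensor([6., 7., 8., 9.])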
@@ -717,3 +992,34 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                          self.model_float16_groups)
        copy_group_params(self.shard_fp32_groups,
                          self.model_fp32_groups)

    def _copy_model_params_to_main_params(self):
        """
        Copy model params to main params.

        During finetuning, this method is used to reload the main params from
        the model params. This copy does not make use of the grad buffer as
        an intermediary.
        """

        # Utility method for copying group params.
        def copy_group_params(model_groups, shard_main_groups):
            for model_group, shard_main_group in zip(model_groups,
                                                     shard_main_groups):
                for model_param, shard_main_param in zip(model_group,
                                                         shard_main_group):

                    param_range_map = self.get_model_param_range_map(model_param)
                    param_range = param_range_map["param"]
                    assert param_range.size == shard_main_param.nelement()

                    shard_model_param = model_param.view(-1) \
                        [param_range.start:param_range.end]

                    shard_main_param.data.copy_(shard_model_param)

        # Copy model groups to shard groups.
        copy_group_params(self.model_float16_groups,
                          self.shard_fp32_from_float16_groups)
        copy_group_params(self.model_fp32_groups,
                          self.shard_fp32_groups)
megatron/optimizer/optimizer.py

# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

"""Megatron optimizer."""
...
@@ -219,12 +219,12 @@ class MegatronOptimizer(ABC):
            unwrapped_model = unwrap_model(
                unwrapped_model, (torchDDP, LocalDDP, Float16Module))

            if unwrapped_model.share_word_embeddings:
                word_embeddings_weight = unwrapped_model.word_embeddings_weight()
            if unwrapped_model.share_embeddings_and_output_weights:
                weight = unwrapped_model.shared_embedding_or_output_weight()
                if args.DDP_impl == 'local':
                    grad = word_embeddings_weight.main_grad
                    grad = weight.main_grad
                else:
                    grad = word_embeddings_weight.grad
                    grad = weight.grad
                torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
...
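For background: when the input embedding and the output projection share one weight that lives on both the first and last pipeline stages, their gradients must be summed across the embedding group before the optimizer step, which is what the all_reduce above does under the renamed attributes. A minimal standalone sketch, assuming an initialized process group; the helper name and group handle are illustrative, not the commit's mpu API:

import torch
import torch.distributed as dist

def allreduce_shared_embedding_grad(weight, embedding_group=None):
    # Sum the gradient of the tied embedding/output weight across the ranks
    # that each hold a copy (first and last pipeline stages in Megatron).
    grad = weight.main_grad if hasattr(weight, "main_grad") else weight.grad
    dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=embedding_group)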
@@ -275,7 +275,6 @@ class MegatronOptimizer(ABC):
                coalesced, grads)):
            buf.copy_(synced)

    def reduce_model_grads(self, args, timers):
        """All-reduce all grads, and all-reduce embeddings."""
...
megatron/optimizer_param_scheduler.py
...
@@ -9,7 +9,7 @@ from megatron import print_rank_0
class OptimizerParamScheduler(object):
    """Anneals learning rate and weight decay"""

    def __init__(self, optimizer, max_lr, min_lr,
    def __init__(self, optimizer, init_lr, max_lr, min_lr,
                 lr_warmup_steps, lr_decay_steps, lr_decay_style,
                 start_wd, end_wd, wd_incr_steps, wd_incr_style,
                 use_checkpoint_opt_param_scheduler=True,
...
@@ -18,10 +18,12 @@ class OptimizerParamScheduler(object):
        # Class values.
        self.optimizer = optimizer

        self.init_lr = init_lr
        self.max_lr = float(max_lr)
        self.min_lr = min_lr
        assert self.min_lr >= 0.0
        assert self.max_lr >= self.min_lr
        assert self.init_lr <= self.max_lr

        self.lr_warmup_steps = lr_warmup_steps
        self.num_steps = 0
...
@@ -80,8 +82,14 @@ class OptimizerParamScheduler(object):
        # Use linear warmup for the initial part.
        if self.lr_warmup_steps > 0 and self.num_steps <= self.lr_warmup_steps:
            return self.max_lr * float(self.num_steps) / \
                float(self.lr_warmup_steps)
            return (
                self.init_lr
                + (
                    (self.max_lr - self.init_lr)
                    * float(self.num_steps)
                    / float(self.lr_warmup_steps)
                )
            )

        # If the learning rate is constant, just return the initial value.
        if self.lr_decay_style == 'constant':
...
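The change makes warmup ramp linearly from init_lr to max_lr instead of from 0 to max_lr: lr(step) = init_lr + (max_lr - init_lr) * step / warmup_steps. A quick numeric check with illustrative values (not taken from the commit):

def warmup_lr(step, init_lr, max_lr, warmup_steps):
    # Linear warmup from init_lr to max_lr, matching the new branch above
    # (the old branch was equivalent to init_lr = 0).
    return init_lr + (max_lr - init_lr) * float(step) / float(warmup_steps)

print(warmup_lr(0, 1e-5, 3e-4, 2000))      # 1e-05    (starts at init_lr, not 0)
print(warmup_lr(1000, 1e-5, 3e-4, 2000))   # 0.000155 (halfway point)
print(warmup_lr(2000, 1e-5, 3e-4, 2000))   # 0.0003   (reaches max_lr)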
megatron/text_generation/forward_step.py
...
@@ -7,46 +7,18 @@ from collections.abc import Iterable

import torch

from megatron import get_args
from megatron.core import mpu
from megatron.core import mpu, InferenceParams
from .communication import (
    send_to_next_pipeline_rank,
    recv_from_prev_pipeline_rank_)


class InferenceParams:
    """Inference parameters that are passed to the main model in order
    to efficienly calculate and store the context during inference."""

    def __init__(self, max_batch_size, max_sequence_len):
        """Note that offsets are set to zero and we always set the
        flag to allocate memory. After the first call, make sure to
        set this flag to False."""
        self.max_sequence_len = max_sequence_len
        self.max_batch_size = max_batch_size
        self.sequence_len_offset = 0
        self.batch_size_offset = 0
        self.key_value_memory_dict = {}

    def swap_key_value_dict(self, batch_idx):
        "swap between batches"
        if len(self.key_value_memory_dict) == 0:
            raise ValueError("should not swap when dict in empty")

        for layer_number in self.key_value_memory_dict.keys():
            inference_key_memory, inference_value_memory = self.key_value_memory_dict[layer_number]
            assert len(batch_idx) == inference_key_memory.shape[1]  ## make sure batch size is the same
            new_inference_key_memory = inference_key_memory[:, batch_idx]
            new_inference_value_memory = inference_value_memory[:, batch_idx]
            self.key_value_memory_dict[layer_number] = (
                new_inference_key_memory, new_inference_value_memory)


class ForwardStep:
    """Forward step function with all the communications.
    We use a class here to hide the inference parameters
    from the outside caller."""

    def __init__(self, model, max_batch_size, max_sequence_len):
    def __init__(self, model, max_batch_size, max_sequence_length):
        """Set values so we don't need to do it multiple times."""
        # Make sure model is in eval mode.
        assert not isinstance(model, Iterable), \
...
@@ -55,7 +27,7 @@ class ForwardStep:
        self.model = model

        # Initialize inference parameters.
        self.inference_params = InferenceParams(max_batch_size,
                                                max_sequence_len)
                                                max_sequence_length)

        # Pipelining arguments.
        args = get_args()
        self.pipeline_size_larger_than_one = (
...
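The local InferenceParams class is dropped in favor of the one exported by megatron.core, and ForwardStep keeps constructing it the same way. A hedged sketch of the call pattern implied by this hunk (the sizes are placeholder values):

from megatron.core import InferenceParams

# KV-cache bookkeeping object, constructed the way ForwardStep does above;
# 8 and 2048 are illustrative max_batch_size / max_sequence_length values.
inference_params = InferenceParams(8, 2048)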
megatron/text_generation/tokenization.py
...
@@ -30,8 +30,11 @@ def detokenize_generations(tokens_gpu_tensor,
        if return_segments:
            words = []
            for token in sequence_tokens:
                if args.tokenizer_type in ['SentencePieceTokenizer',
                                           'GPTSentencePieceTokenizer']:
                    word = tokenizer.decoder[token]
                elif args.tokenizer_type == 'NullTokenizer':
                    word = str(token)
                else:
                    word = tokenizer.tokenizer.decoder[token]
                word = bytearray(
...
megatron/text_generation_server.py
...
@@ -237,5 +237,5 @@ class MegatronServer(object):
        api = Api(self.app)
        api.add_resource(MegatronGenerate, '/api', resource_class_args=[model])

    def run(self, url):
        self.app.run(url, threaded=True, debug=False)
    def run(self, url, port):
        self.app.run(url, threaded=True, debug=False, port=port)
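run() now takes the port explicitly instead of relying on Flask's default. A hedged usage sketch; the model object and bind address below are placeholders, not values from the commit:

# Hypothetical launch snippet: start the REST server on an explicit port now
# that run() accepts one. `model` is assumed to be an initialized Megatron model.
server = MegatronServer(model)
server.run("0.0.0.0", port=5000)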
megatron/tokenizer/tokenizer.py
...
@@ -15,19 +15,19 @@ def build_tokenizer(args):
        print('> building {} tokenizer ...'.format(args.tokenizer_type),
              flush=True)

    if args.tokenizer_type not in ['SentencePieceTokenizer',
                                   'GPTSentencePieceTokenizer']:
        assert args.vocab_file is not None

    # Select and instantiate the tokenizer.
    if args.tokenizer_type == 'BertWordPieceLowerCase':
        assert args.vocab_file is not None
        tokenizer = _BertWordPieceTokenizer(vocab_file=args.vocab_file,
                                            lower_case=True,
                                            vocab_extra_ids=args.vocab_extra_ids)
    elif args.tokenizer_type == 'BertWordPieceCase':
        assert args.vocab_file is not None
        tokenizer = _BertWordPieceTokenizer(vocab_file=args.vocab_file,
                                            lower_case=False,
                                            vocab_extra_ids=args.vocab_extra_ids)
    elif args.tokenizer_type == 'GPT2BPETokenizer':
        assert args.vocab_file is not None
        assert args.merge_file is not None
        tokenizer = _GPT2BPETokenizer(args.vocab_file, args.merge_file)
    elif args.tokenizer_type == 'SentencePieceTokenizer':
...
@@ -36,10 +36,13 @@ def build_tokenizer(args):
    elif args.tokenizer_type == 'GPTSentencePieceTokenizer':
        assert args.tokenizer_model is not None
        tokenizer = _GPTSentencePieceTokenizer(args.tokenizer_model)
    elif args.tokenizer_type == 'NullTokenizer':
        assert args.vocab_size is not None
        tokenizer = _NullTokenizer(args.vocab_size)
    else:
        raise NotImplementedError('{} tokenizer is not '
                                  'implemented.'.format(args.tokenizer_type))

    # Add vocab size.
    args.padded_vocab_size = _vocab_size_with_padding(tokenizer.vocab_size,
                                                      args)
...
@@ -460,7 +463,6 @@ class _SentencePieceTokenizer(AbstractTokenizer):
    def additional_special_tokens_ids(self):
        return [self.vocab[k] for k in self._t5_tokens]


class _GPTSentencePieceTokenizer(_SentencePieceTokenizer):
    """SentencePieceTokenizer-Megatron wrapper"""
...
@@ -499,3 +501,36 @@ class _GPTSentencePieceTokenizer(_SentencePieceTokenizer):
    @property
    def additional_special_tokens_ids(self):
        return None


class _NullTokenizer:
    def __init__(self, vocab_size):
        vocab_size = int(vocab_size)
        self._eos_id = vocab_size
        self.vocab_size = vocab_size + 1

    def tokenize(self, text):
        return [int(x) for x in text.split(' ')]

    def detokenize(self, ids):
        text = [str(x) for x in ids]
        return ' '.join(text)

    @property
    def cls(self):
        return -1

    @property
    def sep(self):
        return -1

    @property
    def mask(self):
        return -1

    @property
    def eod(self):
        return self._eos_id

    @property
    def additional_special_tokens_ids(self):
        return None
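_NullTokenizer treats the "text" as space-separated integer token ids and reserves one extra id for end-of-document, which is useful when the data is already tokenized. A quick illustration (values are arbitrary):

tok = _NullTokenizer(vocab_size=32000)
ids = tok.tokenize("17 4096 9")       # -> [17, 4096, 9]
text = tok.detokenize(ids)            # -> "17 4096 9"
assert tok.eod == 32000               # EOD takes the extra slot ...
assert tok.vocab_size == 32001        # ... so the reported vocab is one larger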
megatron/training.py

# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

"""Pretrain utilities."""
...
@@ -20,6 +20,7 @@ from megatron import get_num_microbatches
from megatron import is_last_rank
from megatron import update_num_microbatches
from megatron.core import mpu, tensor_parallel
from megatron.core.utils import get_model_config
from megatron import print_rank_0
from megatron import print_rank_last
from megatron.checkpointing import load_checkpoint
...
@@ -113,6 +114,7 @@ def pretrain(train_valid_test_dataset_provider,
    timers('model-and-optimizer-setup').stop()
    print_datetime('after model, optimizer, and learning rate '
                   'scheduler are built')
    config = get_model_config(model[0])

    # Data stuff.
    timers('train/valid/test-data-iterators-setup', log_level=0).start(
...
@@ -140,38 +142,44 @@ def pretrain(train_valid_test_dataset_provider,
    print_rank_0('done with setup ...')
    timers.log(['model-and-optimizer-setup',
                'train/valid/test-data-iterators-setup'], barrier=True)
    print_rank_0('training ...')

    iteration = 0
    if not args.skip_train:
        print_rank_0('training ...')

    if args.dataloader_type == 'cyclic' and args.retro_add_retriever:
        args.train_iters = args.retro_cyclic_train_iters
        print_rank_0("retro cyclic train iters : %d" % args.train_iters)
        if args.dataloader_type == 'cyclic' and args.retro_add_retriever:
            args.train_iters = args.retro_cyclic_train_iters
            print_rank_0("retro cyclic train iters : %d" % args.train_iters)

    if args.do_train and args.train_iters > 0:
        iteration = train(forward_step_func,
                          model, optimizer, opt_param_scheduler,
                          train_data_iterator, valid_data_iterator,
                          process_non_loss_data_func)
    print_datetime('after training is done')

        iteration = 0
        if args.do_train and args.train_iters > 0:
            iteration = train(forward_step_func,
                              model, optimizer, opt_param_scheduler,
                              train_data_iterator, valid_data_iterator,
                              process_non_loss_data_func, config)
        print_datetime('after training is done')

        if args.save and iteration != 0:
            save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
    else:
        print_rank_0('skipping training (--skip-train is on) ...')
        iteration = args.iteration

    if args.do_valid:
        prefix = 'the end of training for val data'
        prefix = f'iteration {iteration} on validation set'
        evaluate_and_print_results(prefix, forward_step_func,
                                   valid_data_iterator, model,
                                   iteration, process_non_loss_data_func,
                                   False)

    if args.save and iteration != 0:
        save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
                                   iteration, process_non_loss_data_func, config,
                                   verbose=True, write_to_tensorboard=not args.skip_train)

    if args.do_test:
        # Run on test data.
        prefix = 'the end of training for test data'
        prefix = f'iteration {iteration} on test set'
        evaluate_and_print_results(prefix, forward_step_func,
                                   test_data_iterator, model,
                                   0, process_non_loss_data_func,
                                   True)
                                   iteration, process_non_loss_data_func, config,
                                   verbose=True, write_to_tensorboard=not args.skip_train)


def update_train_iters(args):
...
@@ -345,6 +353,7 @@ def get_optimizer_param_scheduler(optimizer):
    opt_param_scheduler = OptimizerParamScheduler(
        optimizer,
        init_lr=args.lr_warmup_init,
        max_lr=args.lr,
        min_lr=args.min_lr,
        lr_warmup_steps=lr_warmup_steps,
...
@@ -402,7 +411,7 @@ def setup_model_and_optimizer(model_provider_func,

def train_step(forward_step_func, data_iterator,
               model, optimizer, opt_param_scheduler):
               model, optimizer, opt_param_scheduler, config):
    """Single training step."""
    args = get_args()
    timers = get_timers()
...
@@ -417,18 +426,24 @@ def train_step(forward_step_func, data_iterator,
    timers('forward-backward', log_level=1).start(
        barrier=args.barrier_with_L1_time)
    forward_backward_func = get_forward_backward_func()
    fwd_bwd_timers = timers if args.timing_log_level > 1 else None

    # set timers to None if none of the timers in fwd_bwd are active, just to save the checks
    if args.timing_log_level < 2:
        config.timers = None

    losses_reduced = forward_backward_func(
        forward_step_func=forward_step_func,
        data_iterator=data_iterator,
        model=model,
        num_microbatches=get_num_microbatches(),
        dtype=args.params_dtype,
        tensor_shape=(args.seq_length, args.micro_batch_size, args.hidden_size),
        grad_scaler=optimizer.scale_loss,
        sequence_parallel=args.sequence_parallel,
        forward_only=False,
        timers=fwd_bwd_timers)
        seq_length=args.seq_length,
        micro_batch_size=args.micro_batch_size,
        decoder_seq_length=args.decoder_seq_length,
        forward_only=False)

    # reset timers if necessary
    if config.timers is None:
        config.timers = timers
    timers('forward-backward').stop()

    # Empty unused memory.
...
@@ -671,7 +686,7 @@ def save_checkpoint_and_time(iteration, model, optimizer, opt_param_scheduler):

def train(forward_step_func, model, optimizer, opt_param_scheduler,
          train_data_iterator, valid_data_iterator,
          process_non_loss_data_func):
          process_non_loss_data_func, config):
    """Train the model function."""
    args = get_args()
    timers = get_timers()
...
@@ -689,10 +704,20 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
    # Iterations.
    iteration = args.iteration

    # Setup some training config params
    config.grad_scale_func = optimizer.scale_loss
    config.timers = timers

    timers('interval-time', log_level=0).start(barrier=True)
    print_datetime('before the start of training step')
    report_memory_flag = True
    while iteration < args.train_iters:
        if args.profile and \
           iteration == args.profile_step_start and \
           torch.distributed.get_rank() in args.profile_ranks:
            torch.cuda.cudart().cudaProfilerStart()
            torch.autograd.profiler.emit_nvtx(record_shapes=True).__enter__()

        update_num_microbatches(args.consumed_train_samples)
        args.curr_iteration = iteration
        loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = \
...
@@ -700,7 +725,8 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
                       train_data_iterator,
                       model,
                       optimizer,
                       opt_param_scheduler)
                       opt_param_scheduler,
                       config)
        iteration += 1
        args.consumed_train_samples += mpu.get_data_parallel_world_size() * \
                                       args.micro_batch_size * \
...
@@ -730,7 +756,7 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
            evaluate_and_print_results(prefix, forward_step_func,
                                       valid_data_iterator, model,
                                       iteration, process_non_loss_data_func,
                                       False)
                                       config, False)

        # Checkpointing
        saved_checkpoint = False
...
@@ -772,6 +798,10 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
            print_datetime('exiting program at iteration {}'.format(iteration))
            sys.exit()

        if args.profile and \
           iteration == args.profile_step_end and \
           torch.distributed.get_rank() in args.profile_ranks:
            torch.cuda.cudart().cudaProfilerStop()

    return iteration
...
@@ -780,6 +810,7 @@ def evaluate(forward_step_func,
             data_iterator,
             model,
             process_non_loss_data_func,
             config,
             verbose=False):
    """Evaluation."""
    args = get_args()
...
@@ -793,25 +824,33 @@ def evaluate(forward_step_func,
    total_loss_dict = {}

    # make validation batch size independent from training batch size
    eval_batch_size = args.global_batch_size
    eval_num_microbatches = eval_batch_size // \
        (args.micro_batch_size * args.data_parallel_size)

    with torch.no_grad():
        iteration = 0
        if verbose:
            print_rank_0(f'Evaluating on {args.eval_iters * eval_batch_size} samples')
        while iteration < args.eval_iters:
            iteration += 1
            if verbose and iteration % args.log_interval == 0:
                print_rank_0('Evaluating iter {}/{}'.format(iteration,
                                                            args.eval_iters))
            if verbose:
                print_rank_0(f'Evaluating iter {iteration}/{args.eval_iters}')

            forward_backward_func = get_forward_backward_func()
            # Don't care about timing during evaluation
            config.timers = None
            loss_dicts = forward_backward_func(
                forward_step_func=forward_step_func,
                data_iterator=data_iterator,
                model=model,
                num_microbatches=get_num_microbatches(),
                dtype=args.params_dtype,
                tensor_shape=(args.seq_length, args.micro_batch_size, args.hidden_size),
                sequence_parallel=args.sequence_parallel,
                forward_only=True,
                timers=None)
                num_microbatches=eval_num_microbatches,
                seq_length=args.seq_length,
                micro_batch_size=args.micro_batch_size,
                decoder_seq_length=args.decoder_seq_length,
                forward_only=True)
            config.timers = get_timers()

            # Empty unused memory
            if args.empty_unused_memory_level >= 1:
...
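Evaluation now consumes a full global batch per iteration, so the number of eval microbatches follows directly from global batch, micro batch and data-parallel size. A quick check with made-up sizes (not values from the commit):

# global_batch_size=64, micro_batch_size=2, data_parallel_size=8
# -> each eval iteration runs 4 microbatches per DP rank.
global_batch_size, micro_batch_size, data_parallel_size = 64, 2, 8
eval_batch_size = global_batch_size
eval_num_microbatches = eval_batch_size // (micro_batch_size * data_parallel_size)
assert eval_num_microbatches == 4
# With eval_iters=10, the verbose message above would report
# eval_iters * eval_batch_size = 640 evaluated samples.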
@@ -824,35 +863,44 @@ def evaluate(forward_step_func,
                        total_loss_dict[key] = total_loss_dict.get(
                            key, torch.cuda.FloatTensor([0.0])) + loss_dict[key]

            args.consumed_valid_samples += mpu.get_data_parallel_world_size() \
                                           * args.micro_batch_size \
                                           * get_num_microbatches()
            args.consumed_valid_samples += eval_batch_size

        collected_non_loss_data = None
        if process_non_loss_data_func is not None and is_last_rank():
            collected_non_loss_data = forward_backward_func(
                forward_step_func, data_iterator, model, optimizer=None,
                timers=None, forward_only=True, collect_non_loss_data=True)
                forward_step_func=forward_step_func,
                data_iterator=data_iterator,
                model=model,
                num_microbatches=get_num_microbatches(),
                seq_length=args.seq_length,
                micro_batch_size=args.micro_batch_size,
                decoder_seq_length=args.decoder_seq_length,
                forward_only=True,
                collect_non_loss_data=True)

    # Move model back to the train mode.
    for model_module in model:
        model_module.train()

    for key in total_loss_dict:
        total_loss_dict[key] /= args.eval_iters * get_num_microbatches()
        total_loss_dict[key] /= args.eval_iters * eval_num_microbatches

    return total_loss_dict, collected_non_loss_data


def evaluate_and_print_results(prefix, forward_step_func,
                               data_iterator, model,
                               iteration, process_non_loss_data_func,
                               verbose=False):
                               iteration, process_non_loss_data_func, config,
                               verbose=False, write_to_tensorboard=True):
    """Helper function to evaluate and dump results on screen."""
    args = get_args()
    writer = get_tensorboard_writer()
    if write_to_tensorboard:
        writer = get_tensorboard_writer()
    else:
        writer = None

    total_loss_dict, collected_non_loss_data = evaluate(
        forward_step_func, data_iterator, model,
        process_non_loss_data_func, verbose)
        process_non_loss_data_func, config, verbose)
    string = ' validation loss at {} | '.format(prefix)
    for key in total_loss_dict:
        string += '{} value: {:.6E} | '.format(key, total_loss_dict[key].item())
...
@@ -886,9 +934,35 @@ def cyclic_iter(iter):
            yield x


def build_train_valid_test_datasets(build_train_valid_test_datasets_provider):
    """Build pretraining datasets."""

    args = get_args()

    # Number of train/valid/test samples.
    if args.train_samples:
        train_samples = args.train_samples
    else:
        train_samples = args.train_iters * args.global_batch_size
    eval_iters = (args.train_iters // args.eval_interval + 1) * \
                 args.eval_iters
    test_iters = args.eval_iters
    train_val_test_num_samples = [train_samples,
                                  eval_iters * args.global_batch_size,
                                  test_iters * args.global_batch_size]
    print_rank_0(' > datasets target sizes (minimum size):')
    print_rank_0(' train: {}'.format(train_val_test_num_samples[0]))
    print_rank_0(' validation: {}'.format(train_val_test_num_samples[1]))
    print_rank_0(' test: {}'.format(train_val_test_num_samples[2]))

    # Build the datasets.
    return build_train_valid_test_datasets_provider(train_val_test_num_samples)
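The sample counts above are derived purely from iteration counts and the global batch size. A quick worked example with illustrative settings, not taken from the commit:

# 1000 train iterations, evaluation every 100 iterations for 10 iterations,
# global batch of 60.
train_iters, eval_interval, eval_iters, global_batch_size = 1000, 100, 10, 60

train_samples = train_iters * global_batch_size                      # 60000
total_eval_iters = (train_iters // eval_interval + 1) * eval_iters   # 11 * 10 = 110
valid_samples = total_eval_iters * global_batch_size                 # 6600
test_samples = eval_iters * global_batch_size                        # 600

print([train_samples, valid_samples, test_samples])   # [60000, 6600, 600]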
def build_train_valid_test_data_loaders(
        build_train_valid_test_datasets_provider):
    """XXX"""
    """Build pretraining data loaders."""

    args = get_args()

    (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None)
...
@@ -908,31 +982,18 @@ def build_train_valid_test_data_loaders(
    # Data loader only on rank 0 of each model parallel group.
    if mpu.get_tensor_model_parallel_rank() == 0:

        # Number of train/valid/test samples.
        if args.train_samples:
            train_samples = args.train_samples
        else:
            train_samples = args.train_iters * args.global_batch_size
        eval_iters = (args.train_iters // args.eval_interval + 1) * \
                     args.eval_iters
        test_iters = args.eval_iters
        train_val_test_num_samples = [train_samples,
                                      eval_iters * args.global_batch_size,
                                      test_iters * args.global_batch_size]
        print_rank_0(' > datasets target sizes (minimum size):')
        print_rank_0(' train: {}'.format(train_val_test_num_samples[0]))
        print_rank_0(' validation: {}'.format(train_val_test_num_samples[1]))
        print_rank_0(' test: {}'.format(train_val_test_num_samples[2]))

        # Build the datasets.
        train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider(
            train_val_test_num_samples)
        # Build datasets.
        train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
            build_train_valid_test_datasets_provider)

        # Build dataloders.
        train_dataloader = build_pretraining_data_loader(
            train_ds, args.consumed_train_samples)
        valid_dataloader = build_pretraining_data_loader(
            valid_ds, args.consumed_valid_samples)
        if args.skip_train:
            valid_dataloader = build_pretraining_data_loader(valid_ds, 0)
        else:
            valid_dataloader = build_pretraining_data_loader(
                valid_ds, args.consumed_valid_samples)
        test_dataloader = build_pretraining_data_loader(test_ds, 0)

        # Flags to know if we need to do training/validation/testing.
...
@@ -958,6 +1019,7 @@ def build_train_valid_test_data_loaders(

def build_train_valid_test_data_iterators(
        build_train_valid_test_datasets_provider):
    """Build pretraining data iterators."""
    args = get_args()
...
pretrain_bert.py
...
@@ -16,6 +16,7 @@ from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.model import BertModel
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group
from megatron.arguments import core_transformer_config_from_args


def model_provider(pre_process=True, post_process=True):
...
@@ -24,8 +25,10 @@ def model_provider(pre_process=True, post_process=True):
    print_rank_0('building BERT model ...')

    args = get_args()
    config = core_transformer_config_from_args(args)
    num_tokentypes = 2 if args.bert_binary_head else 0
    model = BertModel(
        config=config,
        num_tokentypes=num_tokentypes,
        add_binary_head=args.bert_binary_head,
        parallel_output=True,
...
@@ -119,8 +122,6 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
        splits_string=args.split,
        train_valid_test_num_samples=train_val_test_num_samples,
        max_seq_length=args.seq_length,
        masked_lm_prob=args.mask_prob,
        short_seq_prob=args.short_seq_prob,
        seed=args.seed,
        skip_warmup=(not args.mmap_warmup),
        binary_head=args.bert_binary_head)
...
pretrain_gpt.py

# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

"""Pretrain GPT"""
...
@@ -15,12 +15,15 @@ from megatron.model import GPTModel
from megatron.training import pretrain
from megatron.utils import get_ltor_masks_and_position_ids
from megatron.utils import average_losses_across_data_parallel_group
from megatron.arguments import core_transformer_config_from_args


def model_provider(pre_process=True, post_process=True):
    """Build the model."""

    print_rank_0('building GPT model ...')
    config = core_transformer_config_from_args(get_args())
    model = GPTModel(
        config,
        num_tokentypes=0,
        parallel_output=True,
        pre_process=pre_process,
...
@@ -104,7 +107,8 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
        skip_warmup=(not args.mmap_warmup),
        train_data_prefix=args.train_data_path,
        valid_data_prefix=args.valid_data_path,
        test_data_prefix=args.test_data_path)
        test_data_prefix=args.test_data_path,
        data_cache_path=args.data_cache_path)
    print_rank_0("> finished creating GPT datasets ...")

    return train_ds, valid_ds, test_ds
...
@@ -112,8 +116,8 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
if __name__ == "__main__":

    pretrain(train_valid_test_datasets_provider, model_provider,
    pretrain(train_valid_test_datasets_provider,
             model_provider,
             ModelType.encoder_or_decoder,
             forward_step,
             args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})
             args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})
pretrain_gpt_core.py (new file)

# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""Pretrain GPT"""

import torch
from functools import partial
from megatron import get_args
from megatron.arguments import core_transformer_config_from_args
from megatron import print_rank_0
from megatron import get_timers
from megatron import get_tokenizer
from megatron.core import tensor_parallel
from megatron.core.enums import ModelType
from megatron.data.gpt_dataset import build_train_valid_test_datasets
from megatron.core.models.gpt import GPTModel
from megatron.training import pretrain
from megatron.utils import get_ltor_masks_and_position_ids
from megatron.utils import average_losses_across_data_parallel_group


def model_provider(pre_process=True, post_process=True):
    """Build the model."""

    args = get_args()
    config = core_transformer_config_from_args(args)

    print_rank_0('building GPT model ...')
    model = GPTModel(
        config=config,
        vocab_size=args.padded_vocab_size,
        max_sequence_length=args.max_position_embeddings,
        pre_process=pre_process,
        post_process=post_process,
        fp16_lm_cross_entropy=args.fp16_lm_cross_entropy,
        parallel_output=True,
        share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,
        position_embedding_type=args.position_embedding_type,
        rotary_percent=args.rotary_percent
    )
    return model


def get_batch(data_iterator):
    """Generate a batch"""
    args = get_args()
    tokenizer = get_tokenizer()

    # Items and their type.
    keys = ['text']
    datatype = torch.int64

    # Broadcast data.
    if data_iterator is not None:
        data = next(data_iterator)
    else:
        data = None
    data_b = tensor_parallel.broadcast_data(keys, data, datatype)

    # Unpack.
    tokens_ = data_b['text'].long()
    labels = tokens_[:, 1:].contiguous()
    tokens = tokens_[:, :-1].contiguous()

    # Get the masks and postition ids.
    attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
        tokens,
        tokenizer.eod,
        args.reset_position_ids,
        args.reset_attention_mask,
        args.eod_mask_loss)

    return tokens, labels, loss_mask, attention_mask, position_ids


def loss_func(loss_mask, output_tensor):
    losses = output_tensor.float()
    loss_mask = loss_mask.view(-1).float()
    loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()

    # Reduce loss for logging.
    averaged_loss = average_losses_across_data_parallel_group([loss])

    return loss, {'lm loss': averaged_loss[0]}


def forward_step(data_iterator, model):
    """Forward step."""
    args = get_args()
    timers = get_timers()

    # Get the batch.
    timers('batch-generator', log_level=2).start()
    tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
        data_iterator)
    timers('batch-generator').stop()

    output_tensor = model(tokens, position_ids, attention_mask,
                          labels=labels)

    return output_tensor, partial(loss_func, loss_mask)


def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Build train, valid, and test datasets."""
    args = get_args()

    print_rank_0('> building train, validation, and test datasets '
                 'for GPT ...')
    train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
        data_prefix=args.data_path,
        data_impl=args.data_impl,
        splits_string=args.split,
        train_valid_test_num_samples=train_val_test_num_samples,
        seq_length=args.seq_length,
        seed=args.seed,
        skip_warmup=(not args.mmap_warmup),
        train_data_prefix=args.train_data_path,
        valid_data_prefix=args.valid_data_path,
        test_data_prefix=args.test_data_path,
        data_cache_path=args.data_cache_path)
    print_rank_0("> finished creating GPT datasets ...")

    return train_ds, valid_ds, test_ds


if __name__ == "__main__":

    pretrain(train_valid_test_datasets_provider,
             model_provider,
             ModelType.encoder_or_decoder,
             forward_step,
             args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})
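In get_batch() above the token stream is shifted by one position to build next-token labels, and loss_func() averages the per-token losses only over positions where loss_mask is 1. A tiny numeric illustration with made-up values:

import torch

# Next-token targets: labels are the input sequence shifted left by one.
tokens_ = torch.tensor([[11, 12, 13, 14]])
labels = tokens_[:, 1:]        # [[12, 13, 14]]
tokens = tokens_[:, :-1]       # [[11, 12, 13]]

# Masked mean, as in loss_func(): masked positions contribute nothing.
losses = torch.tensor([2.0, 4.0, 6.0])
loss_mask = torch.tensor([1.0, 1.0, 0.0])
loss = torch.sum(losses * loss_mask) / loss_mask.sum()
print(loss)   # tensor(3.) -> only the first two positions are averaged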
pretrain_retro.py
...
@@ -14,7 +14,7 @@ from megatron.core.enums import ModelType
from megatron.model import GPTModel
from megatron.training import pretrain
from megatron.utils import get_ltor_masks_and_position_ids
from tools.retro.pretraining.retro_dataset import get_retro_datasets
from tools.retro.query.retro_dataset import get_retro_datasets

from pretrain_gpt import (
    loss_func,
...
@@ -96,9 +96,9 @@ def forward_step(data_iterator, model):
    timers('batch-generator').stop()

    output_tensor = model(tokens, position_ids, attention_mask,
                          ret_input_ids=neighbor_tokens,
                          ret_position_ids=neighbor_position_ids,
                          ret_attn_mask=neighbor_attention_mask,
                          retriever_input_ids=neighbor_tokens,
                          retriever_position_ids=neighbor_position_ids,
                          retriever_attn_mask=neighbor_attention_mask,
                          labels=labels)

    return output_tensor, partial(loss_func, loss_mask)
...
@@ -115,6 +115,9 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
if __name__ == "__main__":

    pretrain(train_valid_test_datasets_provider, model_provider,
             ModelType.encoder_or_decoder,
             forward_step, args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})
    pretrain(train_valid_test_datasets_provider,
             model_provider,
             ModelType.retro_decoder,
             forward_step,
             args_defaults={'tokenizer_type': 'GPT2BPETokenizer',
                            'retro_add_retriever': True})
pretrain_t5.py
...
@@ -17,6 +17,7 @@ from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.model import T5Model
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group
from megatron.arguments import core_transformer_config_from_args

"""
...
@@ -60,7 +61,9 @@ def model_provider(pre_process=True, post_process=True,
    """Build the model."""

    print_rank_0('building T5 model ...')
    model = T5Model(num_tokentypes=0,
    config = core_transformer_config_from_args(get_args())
    model = T5Model(config=config,
                    num_tokentypes=0,
                    parallel_output=True,
                    pre_process=pre_process,
                    post_process=post_process,
...
@@ -144,8 +147,6 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
        train_valid_test_num_samples=train_val_test_num_samples,
        max_seq_length=args.encoder_seq_length,
        max_seq_length_dec=args.decoder_seq_length,
        masked_lm_prob=args.mask_prob,
        short_seq_prob=args.short_seq_prob,
        seed=args.seed,
        skip_warmup=(not args.mmap_warmup),
        dataset_type='t5')
...
pretrain_vision_classify.py
...
@@ -12,16 +12,18 @@ from megatron.model.vision.classification import VitClassificationModel
from megatron.model.vision.classification import MitClassificationModel
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group
from megatron.arguments import core_transformer_config_from_args


def model_provider(pre_process=True, post_process=True):
    """Build the model."""

    args = get_args()
    config = core_transformer_config_from_args(args)

    if args.vision_backbone_type == 'vit':
        print_rank_0("building VIT model ...")
        model = VitClassificationModel(num_classes=args.num_classes,
        model = VitClassificationModel(config=config,
                                       num_classes=args.num_classes,
                                       pre_process=pre_process,
                                       post_process=post_process)
    elif args.vision_backbone_type == 'mit':
...
pretrain_vision_dino.py
...
@@ -16,10 +16,12 @@ from megatron.utils import average_losses_across_data_parallel_group, unwrap_mod
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module
from megatron.arguments import core_transformer_config_from_args


def model_provider(pre_process=True, post_process=True):
    """Build the model."""
    return DINOPretrainModel(pre_process=pre_process, post_process=post_process)
    config = core_transformer_config_from_args(get_args())
    return DINOPretrainModel(config,
                             pre_process=pre_process,
                             post_process=post_process)


def get_batch(data_iterator):
    """Build the batch."""
...
pretrain_vision_inpaint.py
...
@@ -13,12 +13,15 @@ from megatron.model.vision.inpainting import MitInpaintingModel
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group
from tasks.vision.metrics import SSIM, PSNR
from megatron.arguments import core_transformer_config_from_args


def model_provider(pre_process=True, post_process=True):
    """Build the model."""
    args = get_args()
    config = core_transformer_config_from_args(args)

    if args.vision_backbone_type == 'vit':
        model = VitInpaintingModel(pre_process=pre_process,
        model = VitInpaintingModel(config,
                                   pre_process=pre_process,
                                   post_process=post_process)
    elif args.vision_backbone_type == 'mit':
        model = MitInpaintingModel(pre_process=pre_process,
...
pyproject.toml (new file)

# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

[tool.isort]
profile = "black"                                  # black-compatible
line_length = 100                                  # should match black parameters
py_version = 38                                    # python 3.8 as a target version
known_first_party = ["megatron"]                   # FIRSTPARTY section
known_third_party = ["transformer_engine"]         # THIRDPARTY section
sections = ["FUTURE", "STDLIB", "THIRDPARTY", "FIRSTPARTY", "LOCALFOLDER"]
default_section = "THIRDPARTY"
extend_skip = ["setup.py"]

[tool.black]
line_length = 100
skip_string_normalization = true
# recongized by future versions, disallows to reformat code with incompatible versions
# Matches NeMO version so people working on both codebases don't need two different version of black installed
required_version = "19.10b0"
requirements.txt (new file)
datasets
nltk
numpy
parameterized
pybind11
regex
six
sentencepiece
tensorboard
transformers
ninja
mpi4py
einops
setup.py

from setuptools import setup, find_packages

setup(
    name="megatron.core",
    version="0.1",
    description="Core components of Megatron.",

"""Setup for pip package."""

import importlib.util
import os
import setuptools

spec = importlib.util.spec_from_file_location('package_info', 'megatron/core/package_info.py')
package_info = importlib.util.module_from_spec(spec)
spec.loader.exec_module(package_info)


__contact_emails__ = package_info.__contact_emails__
__contact_names__ = package_info.__contact_names__
__description__ = package_info.__description__
__download_url__ = package_info.__download_url__
__homepage__ = package_info.__homepage__
__keywords__ = package_info.__keywords__
__license__ = package_info.__license__
__package_name__ = package_info.__package_name__
__repository_url__ = package_info.__repository_url__
__version__ = package_info.__version__


if os.path.exists('megatron/core/README.md'):
    with open("megatron/core/README.md", "r", encoding='utf-8') as fh:
        long_description = fh.read()
    long_description_content_type = "text/markdown"
else:
    long_description = 'See ' + __homepage__
    long_description_content_type = "text/plain"


###############################################################################
#                             Dependency Loading                              #
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #

def req_file(filename, folder="megatron/core"):
    with open(os.path.join(folder, filename), encoding='utf-8') as f:
        content = f.readlines()
    # you may also want to remove whitespace characters
    # Example: `\n` at the end of each line
    return [x.strip() for x in content]


install_requires = req_file("requirements.txt")

###############################################################################

setuptools.setup(
    name=__package_name__,
    # Versions should comply with PEP440. For a discussion on single-sourcing
    # the version across setup.py and the project code, see
    # https://packaging.python.org/en/latest/single_source_version.html
    version=__version__,
    description=__description__,
    long_description=long_description,
    long_description_content_type=long_description_content_type,
    # The project's main homepage.
    url=__repository_url__,
    download_url=__download_url__,
    # Author details
    author=__contact_names__,
    author_email=__contact_emails__,
    # maintainer Details
    maintainer=__contact_names__,
    maintainer_email=__contact_emails__,
    # The licence under which the project is released
    license=__license__,
    classifiers=[
        # How mature is this project? Common values are
        #  1 - Planning
        #  2 - Pre-Alpha
        #  3 - Alpha
        #  4 - Beta
        #  5 - Production/Stable
        #  6 - Mature
        #  7 - Inactive
        'Development Status :: 5 - Production/Stable',
        # Indicate who your project is intended for
        'Intended Audience :: Developers',
        'Intended Audience :: Science/Research',
        'Intended Audience :: Information Technology',
        # Indicate what your project relates to
        'Topic :: Scientific/Engineering',
        'Topic :: Scientific/Engineering :: Mathematics',
        'Topic :: Scientific/Engineering :: Image Recognition',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'Topic :: Software Development :: Libraries',
        'Topic :: Software Development :: Libraries :: Python Modules',
        'Topic :: Utilities',
        # Pick your license as you wish (should match "license" above)
        'License :: OSI Approved :: BSD License',
        # Supported python versions
        'Programming Language :: Python :: 3',
        'Programming Language :: Python :: 3.8',
        'Programming Language :: Python :: 3.9',
        # Additional Setting
        'Environment :: Console',
        'Natural Language :: English',
        'Operating System :: OS Independent',
    ],
    packages=find_packages(
        include=("megatron.core")
        include=['megatron.core', 'megatron.core.*'],
    ),
    # Add in any packaged data.
    include_package_data=True,
    # PyPI package information.
    keywords=__keywords__,
)
single.sh (new file)

#!/bin/bash
# This example script is contributed by external user https://github.com/nrailgun
set -ex

######################################
#####################################
export CUDA_DEVICE_MAX_CONNECTIONS=1
export HSA_FORCE_FINE_GRAIN_PCIE=1
export OMP_NUM_THREADS=1
export NCCL_P2P_LEVEL=5
lrank=$OMPI_COMM_WORLD_LOCAL_RANK
RANK=$OMPI_COMM_WORLD_RANK
WORLD_SIZE=$OMPI_COMM_WORLD_SIZE
export NCCL_IB_TIMEOUT=22

# Change the below configurations here
BASE_PATH=./tmp
DATASET_1="./dataset/my-gpt2_text_document"
DATASET="1 ${DATASET_1}"
CHECKPOINT_PATH=./tmp

TP=4
PP=1
HIDDEN_SIZE=4096
NUM_LAYERS=32
NUM_HEADS=32
SEQ_LENGTH=4096

VOCAB_PATH=./gpt2-vocab.json
MERGE_PATH=./gpt2-merges.txt

MICRO_BATCH_SIZE=1
GLOBAL_BATCH_SIZE=60
TRAIN_STEPS=250000
LR=3e-4
MIN_LR=3e-5
LR_WARMUP_STEPS=2000
WEIGHT_DECAY=0.1
GRAD_CLIP=1

APP="python3 -u pretrain_gpt.py \
    --tensor-model-parallel-size $TP \
    --pipeline-model-parallel-size $PP \
    --num-layers $NUM_LAYERS \
    --hidden-size $HIDDEN_SIZE \
    --num-attention-heads $NUM_HEADS \
    --micro-batch-size $MICRO_BATCH_SIZE \
    --global-batch-size $GLOBAL_BATCH_SIZE \
    --seq-length $SEQ_LENGTH \
    --max-position-embeddings $SEQ_LENGTH \
    --train-iters $TRAIN_STEPS \
    --save $CHECKPOINT_PATH \
    --load $CHECKPOINT_PATH \
    --data-path $DATASET \
    --data-impl mmap \
    --split 949,50,1 \
    --distributed-backend nccl \
    --lr $LR \
    --lr-decay-style cosine \
    --min-lr $MIN_LR \
    --weight-decay $WEIGHT_DECAY \
    --clip-grad $GRAD_CLIP \
    --lr-warmup-iters $LR_WARMUP_STEPS \
    --optimizer adam \
    --adam-beta1 0.9 \
    --adam-beta2 0.95 \
    --log-interval 1 \
    --vocab-file ${VOCAB_PATH} \
    --merge-file ${MERGE_PATH} \
    --tokenizer-type GPT2BPETokenizer \
    --save-interval 1000 \
    --eval-interval 1000 \
    --eval-iters 1000 \
    --fp16 \
    --recompute-activations \
    --disable-bias-linear \
    --no-gradient-accumulation-fusion \
    --rank ${RANK} \
    --world_size ${WORLD_SIZE} \
    --dist_url tcp://${1}:34566 \
    --num-workers 2 \
"

case ${lrank} in
[0])
    export HIP_VISIBLE_DEVICES=0,1,2,3
    ${APP}
    ;;
[1])
    export HIP_VISIBLE_DEVICES=0,1,2,3
    ${APP}
    ;;
[2])
    export HIP_VISIBLE_DEVICES=0,1,2,3
    ${APP}
    ;;
[3])
    export HIP_VISIBLE_DEVICES=0,1,2,3
    ${APP}
    ;;
esac