Commit 051f58f1 authored by liangjing

v1

parent 0024a5c6
# Contributing to Megatron-LM
This document outlines the processes and policies for issues and pull requests by non-NVIDIA contributors to the Megatron-LM GitHub repository.
Everyone is welcome to contribute to the project, but development of Megatron-LM continues internally at NVIDIA. When contributing, it is important to ensure that changes are in line with the project direction. Small changes to fix bugs are welcomed and appreciated. If you are proposing large architectural changes, or changes for stylistic reasons, open an issue first so we can discuss them.
PRs will first be pulled into NVIDIA's internal Megatron-LM repo and then pushed back out to the open GitHub repo with proper credit given to the committers.
## Issue policy
Please do file any bugs you find, keeping the following in mind:
- If filing a bug, i.e. you have found something that doesn't work as expected, use the BUG template.
- If you've found a regression in speed or accuracy use the REGRESSION template.
- If you are requesting a new feature or modification of an existing feature use the ENHANCEMENT template.
- If opening an issue to ask a question, no template is needed, but please make your question as clear and concise as possible.
- One issue per bug. Putting multiple things in the same issue makes both discussion and completion unnecessarily complicated.
- Your bug is most likely to get quick attention from the development team if we can easily reproduce it.
- Use proper spelling, grammar, and punctuation.
- Write in an authoritative and technical tone.
## Code submission policy
Here are some dos & don'ts to try and stick to:
### Do:
- Format new code in a style that is consistent with the file being changed. Megatron-LM doesn't (yet) have a style guide or enforced formatting.
- Split your changes into separate, atomic commits, i.e. a commit per feature or fix.
- Make sure your commits are rebased on the master branch.
- Write the commit message subject line in the imperative mood ("Change the default argument for X", not "Changed the default argument for X").
- Write your commit messages in proper English, with care and punctuation.
- Check the spelling of your code, comments and commit messages.
### Don't:
- Submit code that's incompatible with the project licence.
- Touch anything outside the stated scope of the PR. This includes formatting changes to code not relevant to the PR.
- Iterate excessively on your design across multiple commits.
- Include commented-out code.
- Attempt large architectural changes without first opening an issue to discuss.
## Issue and Pull Request Q&A (Updated Jul 2023)
### I've submitted an issue and PR. When can I expect to get some feedback?
Megatron-LM is developed and maintained by a small team of researchers. We will endeavour to read and acknowledge all new issues and PRs within a week. A few rules of thumb:
- Reproducible bugs/regressions and bug/regression fixes are likely to get the attention of maintainers the quickest.
- Issues requesting an enhancement may only receive acknowledgement that they've been read, and may be closed with a "wontfix" label if they're not in line with the project direction. If they are acknowledged and remain open, you can assume the maintainers agree they're a desirable feature.
- Support requests, i.e. requests for help running the code, have the lowest priority and will be responded to as maintainer time permits.
### If my issue or PR isn't getting attention, how long should I wait before pinging one of the project maintainers?
One week if there is no acknowledgement of the initial request.
### Who are the project maintainers I should ping?
The corresponding maintainers at this time are @jaredcasper and @jon-barker.
### Is there a policy for issues and PRs that haven't been touched in X days? Should they be closed?
Yes, starting in July 2023 we have a bot that will mark untouched PRs as "stale" after 60 days.
We have a long backlog of issues and PRs dating back 3.5 years. We are trying to triage these now by working backwards. Older issues that we believe may still be relevant may receive a request to re-test with the latest code. If there's no response they may be closed. Again, if you believe they should be re-opened, just respond with a comment to that effect.
Thank you!
@@ -289,88 +289,3 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE
--------------- NVIDIA Source Code License for SegFormer -----------------
1. Definitions
“Licensor” means any person or entity that distributes its Work.
“Software” means the original work of authorship made available under this
License.
“Work” means the Software and any additions to or derivative works of the
Software that are made available under this License.
The terms “reproduce,” “reproduction,” “derivative works,” and
“distribution” have the meaning as provided under U.S. copyright law;
provided, however, that for the purposes of this License, derivative works
shall not include works that remain separable from, or merely link
(or bind by name) to the interfaces of, the Work.
Works, including the Software, are “made available” under this License by
including in or with the Work either (a) a copyright notice referencing
the applicability of this License to the Work, or (b) a copy of this License.
2. License Grant
2.1 Copyright Grant. Subject to the terms and conditions of this License,
each Licensor grants to you a perpetual, worldwide, non-exclusive,
royalty-free, copyright license to reproduce, prepare derivative works of,
publicly display, publicly perform, sublicense and distribute its Work
and any resulting derivative works in any form.
3. Limitations
3.1 Redistribution. You may reproduce or distribute the Work only if
(a) you do so under this License, (b) you include a complete copy of this
License with your distribution, and (c) you retain without modification any
copyright, patent, trademark, or attribution notices that are present
in the Work.
3.2 Derivative Works. You may specify that additional or different terms
apply to the use, reproduction, and distribution of your derivative works
of the Work (“Your Terms”) only if (a) Your Terms provide that the use
limitation in Section 3.3 applies to your derivative works, and (b) you
identify the specific derivative works that are subject to Your Terms.
Notwithstanding Your Terms, this License (including the redistribution
requirements in Section 3.1) will continue to apply to the Work itself.
3.3 Use Limitation. The Work and any derivative works thereof only may
be used or intended for use non-commercially. Notwithstanding the
foregoing, NVIDIA and its affiliates may use the Work and any derivative
works commercially. As used herein, “non-commercially” means for research
or evaluation purposes only.
3.4 Patent Claims. If you bring or threaten to bring a patent claim against
any Licensor (including any claim, cross-claim or counterclaim in a lawsuit)
to enforce any patents that you allege are infringed by any Work, then
your rights under this License from such Licensor (including the grant
in Section 2.1) will terminate immediately.
3.5 Trademarks. This License does not grant any rights to use any Licensor’s
or its affiliates’ names, logos, or trademarks, except as necessary to
reproduce the notices described in this License.
3.6 Termination. If you violate any term of this License, then your rights
under this License (including the grant in Section 2.1) will terminate
immediately.
4. Disclaimer of Warranty.
THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT.
YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE.
5. Limitation of Liability.
EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL
THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE
SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT
OF OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
(INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION,
LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER
COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN
ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
This diff is collapsed.
This diff is collapsed.
@@ -20,7 +20,6 @@ python tasks/main.py \
 --num-attention-heads 12 \
 --tensor-model-parallel-size 1 \
 --micro-batch-size 128 \
---activations-checkpoint-method uniform \
 --seq-length 512 \
 --max-position-embeddings 512 \
 --load ${CHECKPOINT_PATH} \
...
@@ -29,7 +29,6 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \
 --hidden-size 1024 \
 --num-attention-heads 16 \
 --batch-size 8 \
---activations-checkpoint-method uniform \
 --seq-length 1024 \
 --max-position-embeddings 1024 \
 --log-interval 10 \
...
@@ -29,7 +29,6 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \
 --hidden-size 1024 \
 --num-attention-heads 16 \
 --micro-batch-size 8 \
---activations-checkpoint-method uniform \
 --lr 5.0e-5 \
 --lr-decay-style linear \
 --lr-warmup-fraction 0.065 \
...
@@ -29,7 +29,6 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \
 --hidden-size 1024 \
 --num-attention-heads 16 \
 --micro-batch-size 4 \
---activations-checkpoint-method uniform \
 --lr 1.0e-5 \
 --lr-decay-style linear \
 --lr-warmup-fraction 0.06 \
...
@@ -41,15 +41,14 @@ options=" \
 --save-interval 1000 \
 --save <PATH TO CHECKPOINTS DIRECTORY> \
 --load <PATH TO CHECKPOINTS DIRECTORY> \
 --split 98,2,0 \
 --clip-grad 1.0 \
 --weight-decay 0.1 \
 --adam-beta1 0.9 \
 --adam-beta2 0.95 \
 --init-method-std 0.006 \
 --tensorboard-dir <TENSORBOARD DIRECTORY> \
---fp16 \
---activations-checkpoint-method uniform "
+--fp16 "
 run_cmd="python -u ${DIR}/pretrain_gpt.py $@ ${options}"
...
@@ -9,6 +9,11 @@ scripts use [Slurm](https://slurm.schedmd.com/documentation.html) with the
 schedulers as well.
+
+## Git commit
+
+To replicate these results use Megatron-LM commit: 6985e58938d40ad91ac07b0fddcfad8132e1447e
+
 ## Setup
 All the cluster-dependent variables are in [`CONFIG.sh`](./CONFIG.sh). Please
...
@@ -3,14 +3,17 @@
 """Megatron arguments."""
 import argparse
+import dataclasses
 import json
 import os
 import torch
 import types
+import torch.nn.functional as F
 from megatron.global_vars import set_retro_args, get_retro_args
 from tools.retro.utils import get_args_path as get_retro_args_path
+from megatron.core.transformer import TransformerConfig
 def parse_args(extra_args_provider=None, ignore_unknown_args=False):
     """Parse all arguments."""
@@ -47,9 +50,9 @@ def parse_args(extra_args_provider=None, ignore_unknown_args=False):
     args = parser.parse_args()
     # Args from environment
-    args.rank = int(os.getenv('RANK', '0'))
-    args.world_size = int(os.getenv("WORLD_SIZE", '1'))
+    #args.rank = int(os.getenv('RANK', '0'))
+    #args.world_size = int(os.getenv("WORLD_SIZE", '1'))
     return args
 def validate_args(args, defaults={}):
@@ -71,7 +74,7 @@ def validate_args(args, defaults={}):
     # Checks.
     model_parallel_size = args.pipeline_model_parallel_size * \
                           args.tensor_model_parallel_size
-    assert args.world_size % model_parallel_size == 0, 'world size is not'\
+    assert args.world_size % model_parallel_size == 0, 'world size ({}) is not'\
         ' divisible by tensor parallel size ({}) times pipeline parallel ' \
         'size ({})'.format(args.world_size, args.tensor_model_parallel_size,
                            args.pipeline_model_parallel_size)
@@ -102,12 +105,10 @@ def validate_args(args, defaults={}):
     del args.model_parallel_size
     if args.checkpoint_activations:
-        args.recompute_granularity = 'full'
-        args.recompute_method = 'uniform'
         if args.rank == 0:
-            print('--checkpoint-activations is no longer valid, '
-                  'use --recompute-granularity and --recompute-method instead. '
-                  'Defaulting to recompute-granularity=full and recompute-method=uniform.')
+            print('--checkpoint-activations is no longer valid, use --recompute-activations, '
+                  'or, for more control, --recompute-granularity and --recompute-method.')
+        exit()
     del args.checkpoint_activations
     if args.recompute_activations:
@@ -119,7 +120,7 @@ def validate_args(args, defaults={}):
         # For default to be valid, it should not be provided in the
         # arguments that are passed to the program. We check this by
         # ensuring the arg is set to None.
-        if getattr(args, key) is not None:
+        if getattr(args, key, None) is not None:
             if args.rank == 0:
                 print('WARNING: overriding default arguments for {key}:{v} \
                       with {key}:{v2}'.format(key=key, v=defaults[key],
@@ -314,23 +315,11 @@ def validate_args(args, defaults={}):
         assert args.recompute_method is not None, \
             'for distributed recompute activations to work you '\
             'need to use a recompute method '
-        assert TORCH_MAJOR >= 1 and TORCH_MINOR >= 10, \
+        assert (TORCH_MAJOR, TORCH_MINOR) >= (1, 10), \
             'distributed recompute activations are supported for pytorch ' \
             'v1.10 and above (Nvidia Pytorch container >= 21.07). Current ' \
             'pytorch version is v%s.%s.' % (TORCH_MAJOR, TORCH_MINOR)
-    # Tranformer-Engine/FP8 related checking
-    if args.fp8_e4m3 or args.fp8_hybrid:
-        assert args.transformer_impl == 'transformer_engine', \
-            'transformer-engine required for fp8 training and inference'
-    assert not (args.fp8_e4m3 and args.fp8_hybrid), \
-        'cannot train with both fp8 e4m3 and hybrid formatting'
-    if args.fp16:
-        assert args.transformer_impl == 'local', \
-            'transformer-engine not yet approved for fp16 training and inference'
     if args.recompute_granularity == 'selective':
         assert args.recompute_method is None, \
             'recompute method is not yet supported for ' \
@@ -361,17 +350,36 @@ def validate_args(args, defaults={}):
     if not args.add_bias_linear:
         args.bias_gelu_fusion = False
-    # Load retro args.
-    if args.retro_workdir:
+    # Retro checks.
+    if args.retro_add_retriever:
+        # Sequence parallelism unsupported.
+        assert not args.sequence_parallel, \
+            "retro currently does not support sequence parallelism."
+        # Pipeline parallelism unsupported.
+        assert args.pipeline_model_parallel_size == 1, \
+            "retro currently does not support pipeline parallelism."
+        # Load retro args.
         retro_args_path = get_retro_args_path(args.retro_workdir)
-        if os.path.exists(retro_args_path):
+        assert os.path.exists(retro_args_path), "retro workdir missing args.json"
         with open(retro_args_path) as f:
             retro_args = types.SimpleNamespace(**json.load(f))
             retro_args.retro_return_doc_ids = args.retro_return_doc_ids
             retro_args.retro_gpt_retrieved_length = \
                 args.retro_num_retrieved_chunks * \
                 retro_args.retro_gpt_chunk_length
             set_retro_args(retro_args)
+    # Legacy RoPE arguments
+    if args.use_rotary_position_embeddings:
+        args.position_embedding_type = 'rope'
+    # Would just need to add 'NoPE' as a position_embedding_type to support this, but for now
+    # don't allow it to keep things simple
+    if not args.add_position_embedding and args.position_embedding_type != 'rope':
+        raise RuntimeError('--no-position-embedding is deprecated, use --position-embedding-type')
     # Print arguments.
     _print_args("arguments", args)
@@ -400,31 +408,63 @@ def _print_args(title, args):
 def _check_arg_is_not_none(args, arg):
     assert getattr(args, arg) is not None, '{} argument is None'.format(arg)
+def core_transformer_config_from_args(args):
+    # Translate args to core transformer configuration
+    kw_args = {}
+    for f in dataclasses.fields(TransformerConfig):
+        if hasattr(args, f.name):
+            kw_args[f.name] = getattr(args, f.name)
+    kw_args['persist_layer_norm'] = not args.no_persist_layer_norm
+    kw_args['layernorm_zero_centered_gamma'] = args.apply_layernorm_1p
+    kw_args['deallocate_pipeline_outputs'] = True
+    kw_args['pipeline_dtype'] = args.params_dtype
+    kw_args['batch_p2p_comm'] = not args.overlap_p2p_comm
+    if args.swiglu:
+        kw_args['activation_func'] = F.silu
+        kw_args['gated_linear_unit'] = True
+        kw_args['bias_gelu_fusion'] = False
+    if args.init_method_xavier_uniform:
+        kw_args['init_method'] = torch.nn.init.xavier_uniform_
+        kw_args['scaled_init_method'] = torch.nn.init.xavier_uniform_
+    if args.group_query_attention:
+        kw_args['num_query_groups'] = args.num_query_groups
+    else:
+        kw_args['num_query_groups'] = None
+    return TransformerConfig(**kw_args)
 def _add_transformer_engine_args(parser):
     group = parser.add_argument_group(title='Transformer-Engine')
-    group.add_argument('--fp8-e4m3', action='store_true',
-                       help='E4M3 TransformerLayer', dest='fp8_e4m3')
-    group.add_argument('--fp8-hybrid', action='store_true',
-                       help='Hybrid FP8 TransformerLayer', dest='fp8_hybrid')
-    group.add_argument('--no-fp8-wgrad', action='store_false',
-                       help='Execute wgrad in higher precision even for FP8 runs', dest='fp8_wgrad')
+    group.add_argument('--fp8-format', default=None,
+                       choices=['e4m3', 'hybrid'],
+                       help='Which fp8 format scheme to use for FP8 tensors in the forward and backward pass',
+                       dest='fp8')
     group.add_argument('--fp8-margin', type=int, default=0,
-                       help='Scaling margin for fp8', dest='fp8_margin')
+                       help='Scaling margin for fp8',
+                       dest='fp8_margin')
     group.add_argument('--fp8-interval', type=int, default=1,
-                       help='Scaling update interval for fp8', dest='fp8_interval')
-    group.add_argument('--transformer-impl', default='local',
-                       choices=['local', 'transformer_engine'],
-                       help='Which Transformer implementation to use.',
-                       dest='transformer_impl')
+                       help='Scaling update interval for fp8',
+                       dest='fp8_interval')
     group.add_argument('--fp8-amax-history-len', type=int, default=1,
                        help='Number of steps for which amax history is recorded per tensor',
                        dest='fp8_amax_history_len')
     group.add_argument('--fp8-amax-compute-algo', default='most_recent',
                        choices=['most_recent', 'max'],
                        help='Algorithm for computing amax from history',
                        dest='fp8_amax_compute_algo')
+    group.add_argument('--no-fp8-wgrad', action='store_false',
+                       help='Execute wgrad in higher precision even for FP8 runs',
+                       dest='fp8_wgrad')
+    group.add_argument('--transformer-impl', default='local',
+                       choices=['local', 'transformer_engine'],
+                       help='Which Transformer implementation to use.',
+                       dest='transformer_impl')
+    group.add_argument('--normalization', default='LayerNorm',
+                       choices=['LayerNorm', 'RMSNorm'],
+                       help='Which normalization technique to use.',
+                       dest='normalization')
     return parser
@@ -518,16 +558,26 @@ def _add_network_size_args(parser):
                        'attention. This is set to '
                        ' args.hidden_size // args.num_attention_heads '
                        'if not provided.')
+    group.add_argument('--group-query-attention', action='store_true',
+                       help='Use group-query attention.')
+    group.add_argument('--num-query-groups', type=int, default=1)
     group.add_argument('--max-position-embeddings', type=int, default=None,
                        help='Maximum number of position embeddings to use. '
                        'This is the size of position embedding.')
+    group.add_argument('--position-embedding-type', type=str, default='learned_absolute',
+                       choices=['learned_absolute', 'rope'],
+                       help='Position embedding type.')
     group.add_argument('--use-rotary-position-embeddings', action='store_true',
-                       help='Use rotary positional embeddings or not')
+                       help='Use rotary positional embeddings or not. '
+                       'Deprecated: use --position-embedding-type')
     group.add_argument('--rotary-percent', type=float, default=1.0,
-                       help='Percent of rotary dimension to use, default 100%')
+                       help='Percent of rotary dimension to use, default 100%%')
+    group.add_argument('--rotary-seq-len-interpolation-factor', type=int, default=None,
+                       help='Sequence length interpolation factor for rotary embeddings.')
     group.add_argument('--no-position-embedding',
                        action='store_false',
-                       help='Disable position embedding.',
+                       help='Disable position embedding. Deprecated: use --position-embedding-type',
                        dest='add_position_embedding')
     group.add_argument('--make-vocab-size-divisible-by', type=int, default=128,
                        help='Pad the vocab size to be divisible by this value.'
@@ -559,6 +609,8 @@ def _add_network_size_args(parser):
                        help='Number of Experts in Switch Transformer (None means no Switch)')
     group.add_argument('--untie-embeddings-and-output-weights', action='store_true',
                        help='Untie embeddings and output weights.'),
+    group.add_argument('--embedding-weights-in-fp32', action='store_true',
+                       help='Cast word embedding weights to fp32 before embedding fwd.'),
     return parser
@@ -713,11 +765,25 @@ def _add_training_args(parser):
                        'individual Transformer layers per pipeline stage and do the '
                        'rest without any recomputing at specified granularity'
                        'default) do not apply activations recompute to any layers')
-    group.add_argument('--recompute-num-layers', type=int, default=1,
+    group.add_argument('--recompute-num-layers', type=int, default=None,
                        help='1) uniform: the number of Transformer layers in each '
                        'uniformly divided recompute unit, '
                        '2) block: the number of individual Transformer layers '
                        'to recompute within each pipeline stage.')
+    group.add_argument('--profile', action='store_true',
+                       help='Enable nsys profiling. When using this option, nsys '
+                       'options should be specified in commandline. An example '
+                       'nsys commandline is `nsys profile -s none -t nvtx,cuda '
+                       '-o <path/to/output_file> --force-overwrite true '
+                       '--capture-range=cudaProfilerApi '
+                       '--capture-range-end=stop`.')
+    group.add_argument('--profile-step-start', type=int, default=10,
+                       help='Global step to start profiling.')
+    group.add_argument('--profile-step-end', type=int, default=12,
+                       help='Global step to stop profiling.')
+    group.add_argument('--profile-ranks', nargs='+', type=int, default=[0],
+                       help='Global ranks to profile.')
     # deprecated
     group.add_argument('--checkpoint-activations', action='store_true',
@@ -830,6 +896,9 @@ def _add_learning_rate_args(parser):
     group.add_argument('--lr-warmup-samples', type=int, default=0,
                        help='number of samples to linearly warmup '
                        'learning rate over.')
+    group.add_argument('--lr-warmup-init', type=float, default=0.0,
+                       help='Initial value for learning rate warmup. The '
+                       'scheduler starts warmup from this value.')
     group.add_argument('--warmup', type=int, default=None,
                        help='Old lr warmup argument, do not use. Use one of the'
                        '--lr-warmup-* arguments above')
@@ -941,6 +1010,10 @@ def _add_distributed_args(parser):
                        '--tensor-model-parallel-size instead.')
     group.add_argument('--num-layers-per-virtual-pipeline-stage', type=int, default=None,
                        help='Number of layers per virtual pipeline stage')
+    group.add_argument('--overlap-p2p-communication',
+                       action='store_true',
+                       help='overlap pipeline parallel communication with forward and backward chunks',
+                       dest='overlap_p2p_comm')
     group.add_argument('--distributed-backend', default='nccl',
                        choices=['nccl', 'gloo'],
                        help='Which backend to use for distributed training.')
@@ -985,6 +1058,13 @@
     group.add_argument('--use-distributed-optimizer', action='store_true',
                        help='Use distributed optimizer.')
+    group.add_argument('--rank', default=-1, type=int,
+                       help='node rank for distributed training')
+    group.add_argument('--world_size', type=int, default=-1,
+                       help='number of nodes for distributed training')
+    group.add_argument('--dist_url',
+                       help='Which master node url for distributed training.')
     return parser
@@ -997,6 +1077,9 @@
     group.add_argument('--eval-interval', type=int, default=1000,
                        help='Interval between running evaluation on '
                        'validation set.')
+    group.add_argument('--skip-train', action='store_true',
+                       default=False, help='If set, bypass the training loop, '
+                       'optionally do evaluation for validation/test, and exit.')
     return parser
@@ -1032,7 +1115,11 @@ def _add_data_args(parser):
                        '1) a single data path, 2) multiple datasets in the'
                        'form: dataset1-weight dataset1-path dataset2-weight '
                        'dataset2-path ...')
+    group.add_argument('--data-cache-path', default=None,
+                       help='Path to a directory to hold cached index files.')
+    group.add_argument('--vocab-size', type=int, default=None,
+                       help='Size of vocab before EOD or padding.')
     group.add_argument('--vocab-file', type=str, default=None,
                        help='Path to the vocab file.')
     group.add_argument('--merge-file', type=str, default=None,
@@ -1067,12 +1154,13 @@
                                 'BertWordPieceCase',
                                 'GPT2BPETokenizer',
                                 'SentencePieceTokenizer',
-                                'GPTSentencePieceTokenizer'],
+                                'GPTSentencePieceTokenizer',
+                                'NullTokenizer'],
                        help='What type of tokenizer to use.')
     group.add_argument('--tokenizer-model', type=str, default=None,
                        help='Sentencepiece tokenizer model.')
     group.add_argument('--data-impl', type=str, default='infer',
-                       choices=['lazy', 'cached', 'mmap', 'infer'],
+                       choices=['mmap', 'infer'],
                        help='Implementation of indexed datasets.')
     group.add_argument('--reset-position-ids', action='store_true',
                        help='Reset posistion ids after end-of-document token.')
@@ -1190,14 +1278,14 @@ def _add_vision_args(parser):
     group.add_argument('--swin-backbone-type', type=str, default='tiny',
                        choices=['tiny', 'base', 'h3'],
                        help='pretraining objectives')
     # inpainting arguments
     group.add_argument('--mask-type', type=str, default='random',
                        choices=['random', 'row'],
                        help='mask types')
     group.add_argument('--mask-factor', type=float, default=1.0,
                        help='mask size scaling parameter')
     # dino arguments
     group.add_argument('--iter-per-epoch', type=int, default=1250,
                        help='iterations per epoch')
...
This diff is collapsed.
Megatron Core is a library for efficient and scalable training of transformer-based models.
@@ -2,11 +2,10 @@ import megatron.core.parallel_state
 import megatron.core.tensor_parallel
 import megatron.core.utils
+from .inference_params import InferenceParams
+from .model_parallel_config import ModelParallelConfig
 # Alias parallel_state as mpu, its legacy name
 mpu = parallel_state
-__all__ = [
-    "parallel_state",
-    "tensor_parallel",
-    "utils",
-]
+__all__ = ["parallel_state", "tensor_parallel", "utils", "InferenceParams", "ModelParallelConfig"]
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
from .core import check_is_distributed_checkpoint
from .mapping import LocalNonpersitentObject, ShardedTensor
from .serialization import load, load_common_state_dict, save
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
import json
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Optional
CONFIG_FNAME = 'metadata.json'
class CheckpointingException(Exception):
pass
@dataclass
class CheckpointingConfig:
""" Documents backends used in the checkpoint. """
sharded_backend: str
sharded_backend_version: int = 1
common_backend: str = 'torch'
common_backend_version: int = 1
def check_is_distributed_checkpoint(checkpoint_dir):
return maybe_load_config(checkpoint_dir) is not None
def maybe_load_config(checkpoint_dir: str) -> Optional[CheckpointingConfig]:
config_path = Path(checkpoint_dir, CONFIG_FNAME)
if not config_path.exists():
return None
with config_path.open() as f:
config_dict = json.load(f)
return CheckpointingConfig(**config_dict)
def save_config(config: CheckpointingConfig, checkpoint_dir: str):
config_path = Path(checkpoint_dir, CONFIG_FNAME)
with config_path.open('w') as f:
json.dump(asdict(config), f)
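# Illustrative usage sketch (editorial addition, not part of the committed file),
# assuming `ckpt_dir` points at an existing checkpoint directory:
#
#   save_config(CheckpointingConfig(sharded_backend='zarr'), ckpt_dir)
#   assert check_is_distributed_checkpoint(ckpt_dir)
#   cfg = maybe_load_config(ckpt_dir)  # CheckpointingConfig(sharded_backend='zarr', ...)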
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" Utilities for operating with dicts and lists. """
from collections import defaultdict
from typing import Any, Callable, Iterable, Optional, Tuple, Union
import torch
def extract_matching_values(
x: Union[dict, list], predicate: Callable
) -> Tuple[Union[dict, list], Union[dict, list]]:
""" Return matching and nonmatching values. Keeps hierarchy. """
if isinstance(x, dict):
matching_vals = {}
nonmatching_vals = {}
for k, v in x.items():
if isinstance(v, (list, dict)):
match, nonmatch = extract_matching_values(v, predicate)
if match:
matching_vals[k] = match
if nonmatch or not v:
nonmatching_vals[k] = nonmatch
elif predicate(v):
matching_vals[k] = v
else:
nonmatching_vals[k] = v
else:
assert isinstance(x, list)
matching_vals = []
nonmatching_vals = []
for v in x:
if isinstance(v, (list, dict)) and v:
match, nonmatch = extract_matching_values(v, predicate)
if match:
matching_vals.append(match)
if nonmatch or not v:
nonmatching_vals.append(nonmatch)
elif predicate(v):
matching_vals.append(v)
else:
nonmatching_vals.append(v)
return matching_vals, nonmatching_vals
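# Example sketch (editorial addition, not part of the committed file): splitting a
# nested dict with a predicate while preserving hierarchy.
#
#   match, nonmatch = extract_matching_values(
#       {'a': 1, 'b': {'c': 2, 'd': 'x'}},
#       lambda v: isinstance(v, int))
#   # match    == {'a': 1, 'b': {'c': 2}}
#   # nonmatch == {'b': {'d': 'x'}}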
def diff(x1: Any, x2: Any, prefix: Tuple = ()) -> Tuple[list, list, list]:
mismatch = []
if isinstance(x1, dict) and isinstance(x2, dict):
only_left = [prefix + (k,) for k in x1.keys() - x2.keys()]
only_right = [prefix + (k,) for k in x2.keys() - x1.keys()]
for k in x2.keys() & x1.keys():
_left, _right, _mismatch = diff(x1[k], x2[k], prefix + (k,))
only_left.extend(_left)
only_right.extend(_right)
mismatch.extend(_mismatch)
elif isinstance(x1, list) and isinstance(x2, list):
        only_left = list(range(len(x1) - 1, len(x2) - 1, -1))
        only_right = list(range(len(x2) - 1, len(x1) - 1, -1))
for i, (v1, v2) in enumerate(zip(x1, x2)):
_left, _right, _mismatch = diff(v1, v2, prefix + (i,))
only_left.extend(_left)
only_right.extend(_right)
mismatch.extend(_mismatch)
else:
only_left = []
only_right = []
if isinstance(x1, torch.Tensor) and isinstance(x2, torch.Tensor):
_is_mismatch = not torch.all(x1 == x2)
else:
try:
_is_mismatch = bool(x1 != x2)
except RuntimeError:
_is_mismatch = True
if _is_mismatch:
mismatch.append((prefix, type(x1), type(x2)))
return only_left, only_right, mismatch
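# Example sketch (editorial addition, not part of the committed file): comparing two
# nested dicts.
#
#   only_left, only_right, mismatch = diff(
#       {'a': 1, 'b': {'c': 2}},
#       {'a': 1, 'b': {'c': 3}, 'd': 4})
#   # only_left  == []
#   # only_right == [('d',)]
#   # mismatch   == [(('b', 'c'), int, int)]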
def inspect_keys_types(d: dict, prefix: Tuple = (), indent: int = 4):
print_indent = lambda: print(' ' * indent * len(prefix), end='')
for k, v in d.items():
if isinstance(v, dict):
print_indent()
print(f'> {k}:')
inspect_keys_types(v, prefix + (k,), indent)
else:
print_indent()
if isinstance(v, torch.Tensor):
print(f'> {k}: {type(v)} of shape {v.shape}')
else:
print(f'> {k}: {type(v)}')
def inspect_types(x: Any, prefix: Tuple = (), indent: int = 4):
print_indent = lambda: print(' ' * indent * len(prefix), end='')
if isinstance(x, dict):
print()
for k, v in x.items():
print_indent()
print(f'> {k}: ', end='')
inspect_types(v, prefix + (k,), indent)
elif isinstance(x, list):
print()
for i, v in enumerate(x):
print_indent()
print(f'- {i}: ', end='')
inspect_types(v, prefix + (i,), indent)
else:
if isinstance(x, torch.Tensor):
print(f'Tensor of shape {x.shape}')
else:
try:
x_str = str(x)
except:
x_str = '<no string repr>'
if len(x_str) > 30:
x_str = x_str[:30] + '... (truncated)'
print(f'[{type(x)}]: {x_str}')
def nested_values(x: Union[dict, list]):
x_iter = x.values() if isinstance(x, dict) else x
for v in x_iter:
if isinstance(v, (dict, list)):
yield from nested_values(v)
else:
yield v
def nested_items_iter(x: Union[dict, list]):
x_iter = x.items() if isinstance(x, dict) else enumerate(x)
for k, v in x_iter:
if isinstance(v, (dict, list)):
yield from nested_items_iter(v)
else:
yield x, k, v
def dict_map(f: Callable, d: dict):
for sub_d, k, v in nested_items_iter(d):
sub_d[k] = f(v)
def dict_map_with_key(f: Callable, d: dict):
for sub_d, k, v in nested_items_iter(d):
sub_d[k] = f(k, v)
def dict_list_map_inplace(f: Callable, x: Union[dict, list]):
if isinstance(x, dict):
for k, v in x.items():
x[k] = dict_list_map_inplace(f, v)
elif isinstance(x, list):
x[:] = (dict_list_map_inplace(f, v) for v in x)
else:
return f(x)
return x
def dict_list_map_outplace(f: Callable, x: Union[dict, list]):
if isinstance(x, dict):
return {k: dict_list_map_outplace(f, v) for k, v in x.items()}
elif isinstance(x, list):
return [dict_list_map_outplace(f, v) for v in x]
else:
return f(x)
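# Example sketch (editorial addition, not part of the committed file): applying a
# function to every leaf of a nested dict/list structure.
#
#   doubled = dict_list_map_outplace(lambda v: v * 2, {'a': [1, 2], 'b': 3})
#   # doubled == {'a': [2, 4], 'b': 6}
#   # dict_list_map_inplace does the same but mutates its argument.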
def merge(x1: dict, x2: dict):
if isinstance(x1, dict) and isinstance(x2, dict):
for k, v2 in x2.items():
if k not in x1:
x1[k] = v2
else:
x1[k] = merge(x1[k], v2)
elif isinstance(x1, list) and isinstance(x2, list):
if len(x1) != len(x2):
raise ValueError('Cannot merge two lists with different lengths')
for i, v2 in enumerate(x2):
x1[i] = merge(x1[i], v2)
else:
raise ValueError(f'Duplicate non-dict and non-list values encountered: `{x1}` and `{x2}`')
return x1
def map_reduce(
xs: Iterable,
key_fn: Callable = lambda x: x,
value_fn: Callable = lambda x: x,
reduce_fn: Callable = lambda x: x,
) -> dict:
res = defaultdict(list)
for x in xs:
res[key_fn(x)].append(value_fn(x))
for k in res:
res[k] = reduce_fn(res[k])
return dict(res)
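# Example sketch (editorial addition, not part of the committed file): grouping
# values by key and reducing each group.
#
#   map_reduce([1, 2, 3, 4, 5],
#              key_fn=lambda x: x % 2,
#              reduce_fn=sum)
#   # -> {1: 9, 0: 6}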
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" Core library classes. """
from dataclasses import dataclass, replace
from itertools import chain
from typing import Any, Callable, Dict, Optional, Tuple, Union
import numpy as np
import torch
from .core import CheckpointingException
from .dict_utils import dict_list_map_inplace, dict_list_map_outplace
# These type definitions are just hints to differentiate a plain model state
# dict (StateDict) from a state dict with tensors replaced with ShardedTensors
# (ShardedStateDict).
StateDict = Dict[str, Any]
ShardedStateDict = Dict[str, Any]
ReplicaId = Union[int, Tuple[int, ...]]
@dataclass
class ShardedTensor:
"""Represents a mapping between a local tensor and a global tensor.
Global tensor is assumed to consist of many local tensors distributed
between different processes.
Attributes:
key: unique identifier of a global tensor
data: local tensor data. Can be None only for consistency validation
dtype: tensor dtype
local_shape: local tensor shape
global_shape: global tensor shape
global_offset: offset of a local tensor in a global tensor, specified
in number of tensor elements
axis_fragmentations: global tensor fragmentation of each axis
replica_id: indicates given local tensor's replication wrt. local
tensors in different processes
prepend_axis_num: number of axes prepended to the local tensor
to reflect global tensor shape.
The behavior is similar to unsqueezing the local tensor.
allow_shape_mismatch: if True, during loading, the global shape of a
stored tensor does not have to match the expected global shape.
Useful for representing tensors with flexible shape, e.g. padded.
flattened_range: specifies a slice that should be applied to a flattened
tensor with `local_shape` in order to get the tensor stored as `data`
"""
key: str
data: Optional[torch.Tensor]
dtype: torch.dtype
local_shape: Tuple[int, ...]
global_shape: Tuple[int, ...]
global_offset: Tuple[int, ...]
axis_fragmentations: Optional[Tuple[int, ...]]
replica_id: ReplicaId = 0
prepend_axis_num: int = 0
allow_shape_mismatch: bool = False
flattened_range: Optional[slice] = None
def global_slice(self) -> Tuple[Union[int, slice], ...]:
assert len(self.global_offset) == len(self.local_shape) + self.prepend_axis_num
return tuple(
chain(
(off for off in self.global_offset[: self.prepend_axis_num]),
(
slice(off, off + sh)
for off, sh in zip(
self.global_offset[self.prepend_axis_num :], self.local_shape
)
),
)
)
def global_coordinates(self) -> Tuple[np.ndarray, ...]:
if self.flattened_range is None:
raise CheckpointingException(
f'`global_coordinates` is undefined for'
f' {self.__class__.__name__} without `flattened_range`'
)
local_coords = self.local_coordinates()
assert len(local_coords) + self.prepend_axis_num == len(self.global_offset), (
len(local_coords),
self,
)
global_coords = tuple(
c + off
for c, off in zip((0,) * self.prepend_axis_num + local_coords, self.global_offset)
)
return global_coords
def local_coordinates(self) -> Tuple[np.ndarray, ...]:
if self.flattened_range is None:
raise CheckpointingException(
f'`local_coordinates` is undefined for'
f' {self.__class__.__name__} without `flattened_range`'
)
# TODO: np.unravel_index?
mask = np.zeros(np.product(self.local_shape), dtype=bool)
mask[self.flattened_range] = True
return np.nonzero(mask.reshape(self.local_shape))
def max_allowed_chunks(self) -> Tuple[int, ...]:
chunks = []
for axis_sh, axis_fragm in zip(self.global_shape, self.axis_fragmentations):
if not self.allow_shape_mismatch and axis_sh % axis_fragm != 0:
raise CheckpointingException(
f'Axis shape ({axis_sh}) not divisible' f' by axis fragmentation ({axis_fragm}'
)
axis_chunk_size = axis_sh // axis_fragm
chunks.append(axis_chunk_size)
return tuple(chunks)
def without_data(self):
return replace(self, data=None)
@classmethod
def from_rank_offsets(
cls,
key: str,
data: torch.Tensor,
*rank_offsets: Tuple[int, int, int],
replica_id: ReplicaId = 0,
prepend_axis_num: int = 0,
allow_shape_mismatch: bool = False,
):
"""Allows to construct the ShardedTensor given offset specified in process ranks.
Arguments:
key: unique key
data: local tensor data
rank_offsets: each tuple (axis, axis_rank_offset, axis_fragm)
says that if global tensor is divided into `axis_fragm`
fragment along `axis` axis, then local tensor data
corresponds to the `axis_rank_offset` chunk.
replica_id: see ShardedTensor
prepend_axis_num: see ShardedTensor
allow_shape_mismatch: see ShardedTensor
"""
global_offset = [0] * (data.ndim + prepend_axis_num)
global_shape = ([1] * prepend_axis_num) + list(data.shape)
axis_fragmentations = [1] * (data.ndim + prepend_axis_num)
_seen_axis = set()
for axis, axis_rank_offset, axis_fragm in rank_offsets:
assert axis >= 0 and axis_rank_offset >= 0 and axis_fragm >= 0, (
axis,
axis_rank_offset,
axis_fragm,
)
assert (
axis_rank_offset < axis_fragm
), 'Rank offset must be lower than axis fragmentation'
if axis in _seen_axis:
raise CheckpointingException('Duplicated axis specified')
_seen_axis.add(axis)
local_axis_shape = 1 if axis < prepend_axis_num else data.shape[axis - prepend_axis_num]
global_shape[axis] = axis_fragm * local_axis_shape
global_offset[axis] = axis_rank_offset * local_axis_shape
axis_fragmentations[axis] = axis_fragm
return cls(
key,
data,
data.dtype,
tuple(data.shape),
tuple(global_shape),
tuple(global_offset),
tuple(axis_fragmentations),
replica_id,
prepend_axis_num,
allow_shape_mismatch,
)
def __str__(self):
return f'{self.__class__.__name__}(key=\'{self.key}\')'
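# Example sketch (editorial addition, not part of the committed file): a local
# tensor of shape (4, 8) that is one of 2 fragments along axis 0, held by the rank
# with offset 1 on that axis.
#
#   sh_ten = ShardedTensor.from_rank_offsets(
#       'decoder.linear.weight', torch.zeros(4, 8), (0, 1, 2))
#   # sh_ten.global_shape        == (8, 8)
#   # sh_ten.global_offset       == (4, 0)
#   # sh_ten.axis_fragmentations == (2, 1)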
def is_main_replica(replica_id):
if isinstance(replica_id, int):
return replica_id == 0
return all(r == 0 for r in replica_id)
class LocalNonpersitentObject:
"""Object that should not be stored in a checkpoint, but restored locally.
Wrapping any object inside the state dict with LocalNonpersitentObject
will result in:
- during saving, this object will *not* be stored in the checkpoint
- during loading, a local version of this object will be placed in a state dict
"""
def __init__(self, obj):
self.obj = obj
def unwrap(self):
return self.obj
@dataclass
class ShardedObject:
"""Represents a mapping between a local object and a global object.
Global object is assumed to consist of many local objects distributed
between different processes.
NOTE: Contrary to ShardedTensor, it's impossible to change global object
sharding. Conceptually, ShardedObject is a fully-sharded ShardedTensor
with atomic arbitrary typed elements.
Attributes:
key: unique identifier of a global tensor
data: local object data. Can be None only for consistency validation
global_shape: global object shape
global_offset: offset of a local object in a global object, specified
in number of shards
replica_id: indicates local object replication wrt. local
objects in different processes
"""
key: str
data: object
global_shape: Tuple[int, ...]
global_offset: Tuple[int, ...]
replica_id: ReplicaId = 0
def without_data(self):
return replace(self, data=None)
@property
def unique_key(self):
return f'{self.key}/shard_{".".join(map(str, self.global_offset))}_{".".join(map(str, self.global_shape))}'
def __str__(self):
return f'{self.__class__.__name__}(key=\'{self.key}\')'
@dataclass
class ShardedTensorFactory:
""" Allows to apply transformations to tensors before/after serialization.
The essence of those transformations is that they can be applied to
optimizer states the same way they are applied to the model params.
Builder creates a sub-state-dict out of a tensor before saving, and merger
merges the corresponding state dict after loading.
"""
key: str
data: torch.Tensor
build_fn: Callable[[str, torch.Tensor], ShardedStateDict]
merge_fn: Callable[[StateDict], torch.Tensor]
def build(self):
return self.build_fn(self.key, self.data)
def apply_factories(sharded_state_dict: ShardedStateDict):
def apply(x):
if isinstance(x, ShardedTensorFactory):
x = x.build()
return x
dict_list_map_inplace(apply, sharded_state_dict)
def apply_factory_merges(x1: StateDict, x2: ShardedStateDict):
if isinstance(x2, ShardedTensorFactory):
return x2.merge_fn(x1)
    # The rest is almost the same as the `merge` function from `dict_utils`
if isinstance(x1, dict) and isinstance(x2, dict):
for k, v2 in x2.items():
if k not in x1:
raise ValueError('Different dict keys encountered in `apply_factory_merges`')
else:
x1[k] = apply_factory_merges(x1[k], v2)
elif isinstance(x1, list) and isinstance(x2, list):
if len(x1) != len(x2):
raise ValueError('Cannot merge two lists with different lengths')
for i, v2 in enumerate(x2):
x1[i] = apply_factory_merges(x1[i], v2)
else:
raise ValueError(f'Duplicate non-dict and non-list values encountered: `{x1}` and `{x2}`')
return x1
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" Optimizer related helpers. """
import logging
from copy import deepcopy
from dataclasses import replace
from itertools import chain
from typing import Dict, Iterable, List, Union
logger = logging.getLogger(__name__)
import torch
from .dict_utils import nested_values
from .mapping import (
LocalNonpersitentObject,
ShardedStateDict,
ShardedTensor,
ShardedTensorFactory,
StateDict,
)
from .utils import extract_sharded_tensors, extract_sharded_tensors_and_factories
def get_optim_param_to_id_map(optim_params_iter: Iterable[torch.nn.Parameter]) -> Dict[int, int]:
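    # Map id(param) -> param index, following the optimizer's parameter ordering.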
param_mappings = {}
for i, param in enumerate(optim_params_iter):
if id(param) not in param_mappings:
param_mappings[id(param)] = i
return param_mappings
def get_param_id_to_sharded_param_map(
model_sharded_state_dict: ShardedStateDict, optim_params_iter: Iterable[torch.nn.Parameter]
) -> Dict[int, Union[ShardedTensor, ShardedTensorFactory]]:
model_sharded_state_dict, _ = extract_sharded_tensors_and_factories(model_sharded_state_dict)
id_to_sharded_param_map = {}
param_to_id_map = get_optim_param_to_id_map(optim_params_iter)
for ten in nested_values(model_sharded_state_dict):
if id(ten.data) in param_to_id_map:
id_to_sharded_param_map[param_to_id_map[id(ten.data)]] = ten
else:
logger.debug(f'{ten} is not tracked by the optimizer')
if not id_to_sharded_param_map:
logger.warning(
"Sharded parameters mapping is empty. It means tensors in model state dict"
" do not correspond to tensors in optimizer parameters map."
" Make sure to call state_dict with `keep_vars=True`."
)
return id_to_sharded_param_map
def make_sharded_optimizer_tensor(
model_param: Union[ShardedTensor, ShardedTensorFactory], optim_param: torch.Tensor, prefix: str
) -> Union[ShardedTensor, ShardedTensorFactory]:
if isinstance(model_param, ShardedTensorFactory):
return replace(model_param, key=f'{prefix}.{model_param.key}', data=optim_param)
assert (
tuple(optim_param.shape) == model_param.local_shape
), f'Optimizer shape ({tuple(optim_param.shape)} does not match model shape ({model_param.local_shape})'
return replace(
model_param, key=f'{prefix}.{model_param.key}', data=optim_param, dtype=optim_param.dtype
)
def optim_state_to_sharding_state(
optim_state_dict: StateDict, id_to_sharded_param_map: Dict[int, ShardedTensor]
):
sharded_state = {}
for param_id, param_state in optim_state_dict['state'].items():
sharded_state[param_id] = {}
for state_key, param in param_state.items():
if param_id in id_to_sharded_param_map:
sharded_state[param_id][state_key] = make_sharded_optimizer_tensor(
id_to_sharded_param_map[param_id], param, prefix=f'optimizer.state.{state_key}'
)
else:
raise ValueError(f'Param id {param_id} does not match any model sharded param')
optim_state_dict['param_groups'] = deepcopy(optim_state_dict['param_groups'])
for group in optim_state_dict['param_groups']:
group['params'] = LocalNonpersitentObject(group['params'])
optim_state_dict['state'] = sharded_state
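# Example sketch (editorial addition, not part of the committed file): building the
# param-id -> ShardedTensor map and converting a torch optimizer state dict into a
# shardable form. `model_sharded_sd` is an assumed dict of ShardedTensors whose
# `.data` are the model parameters (i.e. built with `keep_vars=True`); `optimizer`
# is an assumed torch.optim optimizer.
#
#   params = chain.from_iterable(g['params'] for g in optimizer.param_groups)
#   id_to_sharded = get_param_id_to_sharded_param_map(model_sharded_sd, params)
#   optim_sd = optimizer.state_dict()
#   optim_state_to_sharding_state(optim_sd, id_to_sharded)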
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
import logging
import os
from collections import Counter, defaultdict
from itertools import chain
from pathlib import Path
from typing import Iterable, List, Tuple, Union
import numpy as np
import torch
from .core import CheckpointingConfig, maybe_load_config, save_config
from .dict_utils import (
dict_list_map_inplace,
diff,
extract_matching_values,
map_reduce,
merge,
nested_values,
)
from .mapping import (
CheckpointingException,
ShardedObject,
ShardedStateDict,
ShardedTensor,
ShardedTensorFactory,
StateDict,
apply_factories,
apply_factory_merges,
is_main_replica,
)
from .strategies.base import (
LoadCommonStrategy,
LoadShardedStrategy,
SaveCommonStrategy,
SaveShardedStrategy,
StrategyAction,
get_default_strategy,
)
from .utils import extract_sharded_tensors, extract_sharded_tensors_or_nonpersistent
COMMON_STATE_FNAME = 'common.pt'
logger = logging.getLogger(__name__)
def load(
sharded_state_dict: ShardedStateDict,
checkpoint_dir: str,
sharded_strategy: Union[LoadShardedStrategy, None] = None,
common_strategy: Union[LoadCommonStrategy, None] = None,
) -> StateDict:
"""Loading entrypoint.
Arguments:
sharded_state_dict: state dict of the existing model populated with
ShardedTensors. Used as a mapping to determine which parts of
global tensors stored in the checkpoint should be loaded.
checkpoint_dir: directory with the checkpoint
sharded_strategy: configures loading behavior for sharded tensors
common_strategy: configures loading behavior for common data
"""
if common_strategy is not None:
raise NotImplementedError('The only supported common strategy is torch')
checkpoint_dir = Path(checkpoint_dir)
common_state_dict = load_common_state_dict(checkpoint_dir)
if not sharded_state_dict:
return common_state_dict
sharded_objects, sharded_state_dict = load_sharded_objects(sharded_state_dict, checkpoint_dir)
merge(common_state_dict, sharded_objects)
saved_config = maybe_load_config(checkpoint_dir)
if saved_config is None:
raise CheckpointingException(f'{checkpoint_dir} is not a distributed checkpoint')
sh_ten_factories, _ = extract_matching_values(
sharded_state_dict, lambda x: isinstance(x, ShardedTensorFactory)
)
apply_factories(sharded_state_dict)
sharded_state_dict, _ = extract_sharded_tensors_or_nonpersistent(sharded_state_dict)
sharded_state_dict, nonpersistent_state_dict = extract_sharded_tensors(sharded_state_dict)
dict_list_map_inplace(lambda o: o.unwrap(), nonpersistent_state_dict)
merge(common_state_dict, nonpersistent_state_dict)
validate_sharding_integrity(nested_values(sharded_state_dict))
if sharded_strategy is None:
sharded_strategy = get_default_strategy(
StrategyAction.LOAD_SHARDED,
saved_config.sharded_backend,
saved_config.sharded_backend_version,
)
else:
# TODO: implement consistency checks here
pass
loaded_state_dict = sharded_strategy.load(sharded_state_dict, checkpoint_dir)
loaded_state_dict = apply_factory_merges(loaded_state_dict, sh_ten_factories)
merge(common_state_dict, loaded_state_dict)
return common_state_dict
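# Example sketch (editorial addition, not part of the committed file): round-tripping
# a sharded state dict; `sharded_sd` is assumed to map names to ShardedTensor
# instances, every rank participates in both calls, and the destination directory
# must already exist and be empty before saving.
#
#   save(sharded_sd, '/tmp/ckpt')            # rank 0 writes common.pt, all ranks write shards
#   loaded = load(sharded_sd, '/tmp/ckpt')   # returns a plain state dict of torch.Tensors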
# TODO: implement it as common torch strategy
def load_common_state_dict(checkpoint_dir: Path):
return torch.load(Path(checkpoint_dir) / COMMON_STATE_FNAME, map_location='cpu')
def load_sharded_objects(sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
sharded_objects, sharded_state_dict = extract_matching_values(
sharded_state_dict, lambda v: isinstance(v, ShardedObject)
)
def load_sharded_object(sh_obj: ShardedObject):
sh_obj.data = None
load_path = (checkpoint_dir / sh_obj.unique_key).with_suffix('.pt')
loaded_obj = torch.load(load_path)
return loaded_obj
return dict_list_map_inplace(load_sharded_object, sharded_objects), sharded_state_dict
def save(
sharded_state_dict: ShardedStateDict,
checkpoint_dir: str,
sharded_strategy: Union[SaveShardedStrategy, None] = None,
common_strategy: Union[SaveCommonStrategy, None] = None,
):
"""Saving entrypoint.
Extracts ShardedTensors from the given state dict. Rank 0 saves the
"regular" part of the checkpoint to common torch file.
The ShardedTensors are saved according to a strategy specified by the
config.
Arguments:
sharded_state_dict: state dict of the populated with
ShardedTensors. Used as a mapping to determine how local tensors
should be saved as global tensors in the checkpoint.
checkpoint_dir: directory to save the checkpoint to
sharded_strategy: configures sharded tensors saving behavior and backend
common_strategy: configures common data saving behavior and backend
"""
checkpoint_dir = Path(checkpoint_dir)
if torch.distributed.get_rank() == 0:
if not checkpoint_dir.exists():
raise CheckpointingException(
f'Checkpoint destination directory does not exist: {checkpoint_dir}'
)
if next(checkpoint_dir.iterdir(), None) is not None:
raise CheckpointingException(
f'Checkpoint destination directory ({checkpoint_dir}) is not empty'
)
if common_strategy is not None:
raise NotImplementedError('The only supported common strategy is torch')
if sharded_strategy is None:
sharded_strategy = get_default_strategy(StrategyAction.SAVE_SHARDED, 'zarr', 1)
apply_factories(sharded_state_dict)
sharded_state_dict, state_dict = extract_sharded_tensors_or_nonpersistent(sharded_state_dict)
sharded_state_dict, _ = extract_sharded_tensors(sharded_state_dict)
sharded_tensors = list(nested_values(sharded_state_dict))
validate_sharding_integrity(sharded_tensors)
_save_common_dict(state_dict, checkpoint_dir, True)
sharded_strategy.save(sharded_tensors, checkpoint_dir)
save_config(
CheckpointingConfig(sharded_strategy.backend, sharded_strategy.version), checkpoint_dir
)
# TODO: implement it as common torch strategy
def _save_common_dict(
state_dict: StateDict, checkpoint_dir: Path, validate_consistency: bool = False
):
common_state_dict = _extract_and_save_sharded_objects(
state_dict, checkpoint_dir, validate_consistency
)
if torch.distributed.get_rank() == 0:
torch.save(common_state_dict, checkpoint_dir / COMMON_STATE_FNAME)
if validate_consistency:
# TODO: implement checking consistency with rank 0 common dict on other ranks
pass
# torch.distributed.barrier()
# if not torch.distributed.get_rank() == 0:
# rank_0_state_dict = torch.load(checkpoint_dir / COMMON_STATE_FNAME)
# print(diff(common_state_dict, rank_0_state_dict))
def _extract_and_save_sharded_objects(
state_dict: StateDict, checkpoint_dir: Path, validate_consistency: bool = False
):
sharded_objects, state_dict = extract_matching_values(
state_dict, lambda v: isinstance(v, ShardedObject)
)
sharded_objects = list(nested_values(sharded_objects))
if validate_consistency:
validate_objects_sharding_integrity(sharded_objects)
for sh_obj in sharded_objects:
if is_main_replica(sh_obj.replica_id):
save_path = (checkpoint_dir / sh_obj.unique_key).with_suffix('.pt')
os.makedirs(save_path.parent, exist_ok=True)
torch.save(sh_obj.data, save_path)
return state_dict
def validate_sharding_integrity(sharded_tensors: Iterable[ShardedTensor]):
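    # Gather metadata-only shardings from every rank; rank 0 then checks that, for each
    # tensor key, the main replicas tile the global tensor exactly once and agree on
    # dtype, local shape and global shape.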
sharding = [ten.without_data() for ten in sharded_tensors]
all_sharding = [None] * torch.distributed.get_world_size()
torch.distributed.all_gather_object(all_sharding, sharding)
if torch.distributed.get_rank() != 0:
return
key_shardings = defaultdict(list)
for rank, rank_shardings in enumerate(all_sharding):
for sharding in rank_shardings:
key_shardings[sharding.key].append((rank, sharding))
for key, shardings in key_shardings.items():
_validate_sharding_for_key(shardings)
def _validate_sharding_for_key(rank_sharding: List[Tuple[int, ShardedTensor]]):
some_rank_shard = rank_sharding[0][1]
global_shape = some_rank_shard.global_shape
local_shape = some_rank_shard.local_shape
dtype = some_rank_shard.dtype
has_flattened_range = some_rank_shard.flattened_range is not None
for rank, sharding in rank_sharding:
assert sharding.dtype == dtype, (sharding.dtype, dtype, some_rank_shard)
assert sharding.global_shape == global_shape, (
sharding.global_shape,
global_shape,
some_rank_shard,
)
assert sharding.local_shape == local_shape, (
sharding.local_shape,
local_shape,
some_rank_shard,
)
assert (sharding.flattened_range is not None) == has_flattened_range, (
(sharding.flattened_range is not None),
has_flattened_range,
some_rank_shard,
)
shard_access_cnt = _compute_shards_access(rank_sharding)
if has_flattened_range:
map_reduce(
rank_sharding,
lambda x: x[1].global_offset,
lambda x: x[1],
_validate_sharding_for_key_flattened,
)
else:
if not torch.all(shard_access_cnt == 1):
logger.error(f'Invalid access pattern for {rank_sharding[0][1]}: {shard_access_cnt}')
raise CheckpointingException(f'Invalid access pattern for {rank_sharding[0][1]}')
def _compute_shards_access(rank_sharding):
def chunk_offset(sharding):
assert len(sharding.global_offset) == len(sharding.local_shape) + sharding.prepend_axis_num
return tuple(
chain(
(off for off in sharding.global_offset[: sharding.prepend_axis_num]),
(
off // sh
for off, sh in zip(
sharding.global_offset[sharding.prepend_axis_num :], sharding.local_shape
)
),
)
)
shard_access_cnt = torch.zeros(
rank_sharding[0][1].axis_fragmentations, dtype=torch.int, device='cpu'
)
for rank, sharding in rank_sharding:
if is_main_replica(sharding.replica_id):
shard_access_cnt[chunk_offset(sharding)] += 1
# TODO: consider validating different replicas too
return shard_access_cnt
def _validate_sharding_for_key_flattened(tensors_by_shard):
all_slices = []
local_shape = tensors_by_shard[0].local_shape
for sharding in tensors_by_shard:
assert sharding.local_shape == local_shape
sharding: ShardedTensor
if not is_main_replica(sharding.replica_id):
# TODO: this checks only saving (and loading replica_id=0) consistency
continue
all_slices.append((sharding.flattened_range.start, sharding.flattened_range.stop))
starts, stops = map(np.asarray, zip(*sorted(all_slices)))
if (
starts[0] != 0
or stops[-1] != np.product(local_shape)
or not np.all(starts[1:] == stops[:-1])
):
logger.error(
f'Flattened ranges dont cover the whole shard {tensors_by_shard[0]}. Ranges: {(starts, stops)}'
)
raise CheckpointingException(
f'Flattened ranges dont cover the whole shard {tensors_by_shard[0]}'
)
def validate_objects_sharding_integrity(sharded_objects: List[ShardedObject]):
""" Ensure uniqueness of saved objects. """
local_sh_objs = [sh_obj.without_data() for sh_obj in sharded_objects]
all_sh_objs = [None] * torch.distributed.get_world_size()
torch.distributed.all_gather_object(all_sh_objs, local_sh_objs)
if torch.distributed.get_rank() != 0:
return
unique_keys = [
sh_obj.unique_key
for sh_obj in chain.from_iterable(all_sh_objs)
if is_main_replica(sh_obj.replica_id)
]
if len(unique_keys) != len(set(unique_keys)):
duplicates = {k: cnt for k, cnt in Counter(unique_keys).items() if cnt > 1}
logger.error(f'Duplicate ShardedObject keys and counts: {duplicates}')
raise CheckpointingException(f'Duplicate ShardedObject keys: {list(duplicates.keys())}')