OpenDAS / Pai-Megatron-Patch · Commits

Commit deb8370c, authored Jan 09, 2025 by hepj

Initial commit

Pipeline #2198 canceled with stages · Changes: 321 · Pipelines: 1
Showing 20 changed files with 1846 additions and 0 deletions (+1846, −0)
PAI-Megatron-LM-240718/megatron/core/distributed/finalize_model_grads.py  +150 −0
PAI-Megatron-LM-240718/megatron/core/distributed/param_and_grad_buffer.py  +549 −0
PAI-Megatron-LM-240718/megatron/core/enums.py  +10 −0
PAI-Megatron-LM-240718/megatron/core/fusions/__init__.py  +0 −0
PAI-Megatron-LM-240718/megatron/core/fusions/fused_bias_dropout.py  +73 −0
PAI-Megatron-LM-240718/megatron/core/fusions/fused_bias_geglu.py  +85 −0
PAI-Megatron-LM-240718/megatron/core/fusions/fused_bias_gelu.py  +50 −0
PAI-Megatron-LM-240718/megatron/core/fusions/fused_bias_swiglu.py  +89 −0
PAI-Megatron-LM-240718/megatron/core/fusions/fused_cross_entropy.py  +153 −0
PAI-Megatron-LM-240718/megatron/core/fusions/fused_layer_norm.py  +169 −0
PAI-Megatron-LM-240718/megatron/core/fusions/fused_softmax.py  +220 −0
PAI-Megatron-LM-240718/megatron/core/inference/__init__.py  +1 −0
PAI-Megatron-LM-240718/megatron/core/inference/ammo_support/__init__.py  +0 −0
PAI-Megatron-LM-240718/megatron/core/inference/ammo_support/gpt/__init__.py  +1 −0
PAI-Megatron-LM-240718/megatron/core/inference/ammo_support/gpt/model_specs.py  +58 −0
PAI-Megatron-LM-240718/megatron/core/inference/ammo_support/gpt/state_dict_hooks.py  +145 −0
PAI-Megatron-LM-240718/megatron/core/inference/common_inference_params.py  +28 −0
PAI-Megatron-LM-240718/megatron/core/inference/communication_utils.py  +49 −0
PAI-Megatron-LM-240718/megatron/core/inference/engines/__init__.py  +0 −0
PAI-Megatron-LM-240718/megatron/core/inference/engines/abstract_engine.py  +16 −0
Too many changes to show. To preserve performance, only 321 of 321+ files are displayed.
PAI-Megatron-LM-240718/megatron/core/distributed/finalize_model_grads.py (new file, 0 → 100644)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

from typing import List, Optional

import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

from .. import parallel_state
from ..transformer.transformer_config import TransformerConfig
from ..utils import get_attr_wrapped_model, get_model_config


def _allreduce_word_embedding_grads(model: List[torch.nn.Module], config: TransformerConfig):
    """
    All-reduce word embedding grads.

    Reduce grads across first and last stages to ensure that word_embeddings parameters stay in
    sync.
    """

    if (
        parallel_state.is_rank_in_embedding_group(ignore_virtual=True)
        and torch.distributed.get_world_size(parallel_state.get_embedding_group()) > 1
    ):
        if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
            model_module = model[0]
        elif parallel_state.is_pipeline_last_stage(ignore_virtual=True):
            model_module = model[-1]
        else:
            # We do not support an interleaved schedule for models with encoders yet.
            model_module = model[0]

        model_module = get_attr_wrapped_model(model_module, 'pre_process', return_model_obj=True)
        if model_module.share_embeddings_and_output_weights:
            weight = model_module.shared_embedding_or_output_weight()
            grad = weight.main_grad
            torch.distributed.all_reduce(grad, group=parallel_state.get_embedding_group())


def _allreduce_position_embedding_grads(model: List[torch.nn.Module], config: TransformerConfig):
    """
    All-reduce position_embeddings grad across encoder and decoder stages to ensure that position
    embeddings parameters stay in sync.
    """
    if (
        parallel_state.is_rank_in_position_embedding_group()
        and torch.distributed.get_world_size(parallel_state.get_position_embedding_group()) > 1
    ):
        if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
            model_module = model[0]
        elif parallel_state.is_pipeline_last_stage(ignore_virtual=True):
            model_module = model[-1]
        else:
            # We do not support an interleaved schedule for models with encoders yet.
            model_module = model[0]

        model_module = get_attr_wrapped_model(model_module, 'pre_process', return_model_obj=True)
        assert hasattr(model_module, 'position_embeddings')
        grad = model_module.position_embeddings.weight.main_grad
        torch.distributed.all_reduce(grad, group=parallel_state.get_position_embedding_group())


def _allreduce_embedding_grads(model: List[torch.nn.Module], config: TransformerConfig):
    """
    All-reduce both word and position embeddings.
    """
    _allreduce_word_embedding_grads(model, config)
    _allreduce_position_embedding_grads(model, config)


def _allreduce_layernorm_grads(model: List[torch.nn.Module], config: TransformerConfig):
    """
    All-reduce layernorm grads (for sequence parallelism).
    """

    # All-reduce layernorm parameters across model parallel nodes
    # when sequence parallelism is used
    if parallel_state.get_tensor_model_parallel_world_size() > 1 and (
        config.sequence_parallel or config.qk_layernorm
    ):
        grads = []
        for model_chunk in model:
            for name, param in get_attr_wrapped_model(model_chunk, 'named_parameters')():
                if (
                    param.requires_grad
                    and getattr(param, 'sequence_parallel', False)
                    or 'q_layernorm' in name
                    or 'k_layernorm' in name
                ):
                    grad = param.main_grad
                    grads.append(grad.data)
        if grads:
            coalesced = _flatten_dense_tensors(grads)
            torch.distributed.all_reduce(
                coalesced, group=parallel_state.get_tensor_model_parallel_group()
            )
            for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
                buf.copy_(synced)


def finalize_model_grads(model: List[torch.nn.Module], num_tokens: Optional[torch.Tensor] = None):
    """
    All-reduce all model grads across DP replicas, layernorm grads for sequence parallelism,
    embedding grads across first and last pipeline stages (if not tied),
    scale gradients by `num_tokens`.
    """

    config = get_model_config(model[0])

    # All-reduce / reduce-scatter across DP replicas.
    if config.timers is not None:
        config.timers('all-grads-sync', log_level=1).start(barrier=config.barrier_with_L1_time)
    for model_chunk in model:
        model_chunk.finish_grad_sync()
    if config.timers is not None:
        config.timers('all-grads-sync').stop()

    # All-reduce layer-norm grads (for sequence parallelism).
    if config.timers is not None:
        config.timers('layernorm-grads-all-reduce', log_level=1).start(
            barrier=config.barrier_with_L1_time
        )
    _allreduce_layernorm_grads(model, config)
    if config.timers is not None:
        config.timers('layernorm-grads-all-reduce').stop()

    # All-reduce embedding grads (for pipeline parallelism).
    if config.timers is not None:
        config.timers('embedding-grads-all-reduce', log_level=1).start(
            barrier=config.barrier_with_L1_time
        )
    _allreduce_embedding_grads(model, config)
    if config.timers is not None:
        config.timers('embedding-grads-all-reduce').stop()

    # Normalize gradients for per-token loss normalization.
    # If we are normalizing by the number of tokens, we use that as the divisor; this number
    # will be the total number of non-padded tokens in the global batch.
    if num_tokens is not None:
        # The number of tokens is only present on the last stage, so broadcast it
        # to the other ranks in the pipeline parallel group.
        torch.distributed.broadcast(
            num_tokens,
            src=parallel_state.get_pipeline_model_parallel_last_rank(),
            group=parallel_state.get_pipeline_model_parallel_group(),
        )
        # All-reduce across DP ranks.
        torch.distributed.all_reduce(num_tokens, group=parallel_state.get_data_parallel_group())
        for model_chunk in model:
            if num_tokens > 0:
                scaling = 1.0 / num_tokens
                model_chunk.scale_gradients(scaling)
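In Megatron-style training loops this function is normally invoked once per optimizer step, after the final backward pass of the last microbatch, with the list of DDP-wrapped model chunks. A minimal call sketch, assuming the caller already holds such chunks (the names `model_chunks` and `num_tokens` are illustrative, not part of this file):

# Usage sketch (assumed names; not part of this commit). finalize_model_grads expects the
# model chunks to be wrapped so that finish_grad_sync() and scale_gradients() exist on each.
from megatron.core.distributed.finalize_model_grads import finalize_model_grads

def end_of_step(model_chunks, num_tokens=None):
    # num_tokens, if provided, is a 1-element tensor with the non-padded token count of the
    # global batch; gradients are divided by it after all reductions complete.
    finalize_model_grads(model_chunks, num_tokens=num_tokens)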
PAI-Megatron-LM-240718/megatron/core/distributed/param_and_grad_buffer.py (new file, 0 → 100644)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

import logging
import math
import os
from enum import Enum
from typing import Dict, List, Optional

import torch

from ..utils import log_on_each_pipeline_stage
from .distributed_data_parallel_config import DistributedDataParallelConfig

logger = logging.getLogger(__name__)


class BufferType(Enum):
    PARAM = 1
    GRAD = 2


def shard_buffer(buffer: torch.Tensor, data_parallel_world_size: int):
    """
    Shard buffer into data_parallel_world_size chunks of equal size.
    """
    assert buffer.numel() % data_parallel_world_size == 0
    shard_size = buffer.numel() // data_parallel_world_size
    sharded_buffer = [
        buffer[(r * shard_size) : ((r + 1) * shard_size)] for r in range(data_parallel_world_size)
    ]
    return sharded_buffer


class Bucket:
    """
    Bucket to keep track of a subset of the model's gradients. Provides functionality to register
    when params in the bucket have grads ready to be synced; an asynchronous communication call
    is automatically launched when _all_ params in the bucket have grads ready.

    Args:
        ddp_config: DistributedDataParallel config object.
        params: List of parameters whose gradients are collated in this bucket.
        param_data: View in larger ParamAndGradBuffer.param_data that this bucket is responsible for.
        grad_data: View in larger ParamAndGradBuffer.grad_data that this bucket is responsible for.
        offset: Offset of this bucket's view in the larger ParamAndGradBuffer.
        numel_unpadded: Number of unpadded elements in bucket.
        data_parallel_group: Data-parallel process group.
        data_parallel_world_size: World size of the data-parallel group.
        gradient_scaling_factor: This factor is utilized to scale gradients prior to their
            communication. Its application is twofold: it facilitates the averaging of gradients
            and the scaling of gradients in the context of the Mixture of Experts (MoE) model.
    """

    def __init__(
        self,
        ddp_config: DistributedDataParallelConfig,
        params: List[torch.nn.Parameter],
        param_data: Optional[torch.Tensor],
        grad_data: torch.Tensor,
        offset: int,
        numel_unpadded: int,
        data_parallel_group: torch.distributed.ProcessGroup,
        data_parallel_world_size: int,
        gradient_scaling_factor: float,
    ):
        self.ddp_config = ddp_config

        # State for bookkeeping: params is the set of parameters this bucket is
        # responsible for, params_with_grad is the set of parameters with grads
        # available. When overlap_grad_reduce is True, communication (all-reduce
        # or reduce-scatter) is issued when params_with_grad equals params.
        self.params_list = params
        self.params = set(params)
        self.params_with_grad = set()
        self.param_data = param_data
        self.grad_data = grad_data
        # The distributed optimizer needs to keep track of this bucket's offset
        # within the full grad_buffer.
        self.offset = offset
        self.numel_unpadded = numel_unpadded
        self.data_parallel_group = data_parallel_group
        self.data_parallel_world_size = data_parallel_world_size
        self.data_parallel_rank = torch.distributed.get_rank(group=data_parallel_group)
        self.gradient_scaling_factor = gradient_scaling_factor

        self.reset()

    def reset(self):
        """
        Reset metadata in bucket in preparation for the next iteration of training.
        """
        self.params_with_grad = set()
        self.communication_handle = None
        self.is_communication_outstanding = False

    def start_grad_sync(self):
        """
        Initiates grad sync (all-reduce or reduce-scatter) communication operation
        for this bucket.

        When overlap_grad_reduce is set to True, dispatches an asynchronous
        communication call. When overlap_grad_reduce is set to False, makes
        synchronous call.
        """
        assert (
            self.communication_handle is None and not self.is_communication_outstanding
        ), 'Should not have multiple communication calls outstanding at once'

        # Make sure norm of grads in bucket are not NaN
        # prior to data-parallel all-reduce / reduce-scatter.
        if self.ddp_config.check_for_nan_in_grad:
            global_rank = torch.distributed.get_rank()
            norm = self.grad_data.norm(p=2)
            assert not norm.isnan(), (
                f'Rank {global_rank}: found NaN in local grad norm in '
                f'backward pass before data-parallel communication collective. '
                f'Device: {torch.cuda.current_device()}, node: {os.uname()[1]}'
            )

        # gradient_scaling_factor already takes into account whether we are computing
        # an average or sum in the data-parallel collective.
        if self.gradient_scaling_factor != 1.0:
            self.grad_data *= self.gradient_scaling_factor

        # Decide reduce_op.
        reduce_op = torch.distributed.ReduceOp.SUM
        if self.ddp_config.average_in_collective:
            reduce_op = torch.distributed.ReduceOp.AVG

        # Use async_op only when overlap_grad_reduce is True.
        if self.ddp_config.use_distributed_optimizer:
            local_data_view = shard_buffer(self.grad_data, self.data_parallel_world_size)[
                self.data_parallel_rank
            ]
            self.communication_handle = torch.distributed._reduce_scatter_base(
                local_data_view,
                self.grad_data,
                op=reduce_op,
                group=self.data_parallel_group,
                async_op=self.ddp_config.overlap_grad_reduce,
            )
        else:
            self.communication_handle = torch.distributed.all_reduce(
                self.grad_data,
                op=reduce_op,
                group=self.data_parallel_group,
                async_op=self.ddp_config.overlap_grad_reduce,
            )
        if self.ddp_config.overlap_grad_reduce:
            self.is_communication_outstanding = True
        else:
            self.is_communication_outstanding = False

    def finish_grad_sync(self):
        """
        Finishes grad sync (all-reduce or reduce-scatter) communication operation
        for this bucket.

        When overlap_grad_reduce is set to True, waits for asynchronous communication
        call to complete. When overlap_grad_reduce is set to False, makes synchronous call.
        """
        # If overlap_grad_reduce is False, start (and finish) synchronous communication call here.
        if not self.ddp_config.overlap_grad_reduce:
            self.start_grad_sync()
            return
        assert self.communication_handle is not None and self.is_communication_outstanding, (
            f'Communication call has not been issued for this bucket '
            f'({len(self.params_with_grad)}/{len(self.params)} params have grad available)'
        )
        self.communication_handle.wait()

    def register_grad_ready(self, param: torch.nn.Parameter):
        """
        Registers grads for the passed-in param to be "ready" for grad sync.

        When the number of microbatches is greater than 1, we only want to register
        grads as ready when processing the last microbatch and overlap_grad_reduce is True.
        """
        assert param in self.params, 'Param is not in the bucket'
        assert param not in self.params_with_grad, 'Cannot set grad twice'
        assert (
            self.ddp_config.overlap_grad_reduce
        ), 'register_grad_ready() should be called only when overlapping grad reduce'
        self.params_with_grad.add(param)
        # If all params in bucket have grads available, issue communication call.
        if len(self.params_with_grad) == len(self.params):
            self.start_grad_sync()


class ParamAndGradBuffer:
    """
    Groups parameters and gradients into a contiguous buffer, and then breaks the buffer into
    buckets with roughly `bucket_size` parameters each.

    Args:
        ddp_config: DistributedDataParallel config object.
        param_dtype: Type of param tensor.
        grad_dtype: Type of grad tensor.
        params: List of parameters whose parameters and gradients are collated in the underlying
            tensor.
        data_parallel_group: Data-parallel process group.
        bucket_size: The rough size of each bucket in terms of number of parameters.
        param_to_name: Mapping from `torch.nn.Parameter` to name (for logging purposes).
        gradient_scaling_factor: This factor is utilized to scale gradients prior to their
            communication. Its application is twofold: it facilitates the averaging of gradients
            and the scaling of gradients in the context of the Mixture of Experts (MoE) model.
    """

    def __init__(
        self,
        ddp_config: DistributedDataParallelConfig,
        param_dtype: torch.dtype,
        grad_dtype: torch.dtype,
        params: List[torch.nn.Parameter],
        data_parallel_group: torch.distributed.ProcessGroup,
        bucket_size: int,
        param_to_name: Dict[torch.nn.Parameter, str],
        gradient_scaling_factor: float,
    ):
        self.ddp_config = ddp_config

        # Check that params are unique.
        unique_params = set()
        for param in params:
            assert param not in unique_params
            unique_params.add(param)
        del unique_params

        # Store attributes that will be needed later.
        self.param_dtype = param_dtype
        self.grad_dtype = grad_dtype
        self.data_parallel_group = data_parallel_group
        self.data_parallel_world_size = torch.distributed.get_world_size(
            group=self.data_parallel_group
        )
        self.gradient_scaling_factor = gradient_scaling_factor
        self.is_last_microbatch = True

        # Data structures to store underlying buckets and relevant indexing data.
        self.buckets = []
        self.param_to_bucket = {}  # Param -> bucket mapping.
        self.param_index_map = {}  # Param -> location in buffer mapping (used in dist. optimizer).

        def _pad(number_to_be_padded: int, divisor: int) -> int:
            return int(math.ceil(number_to_be_padded / divisor) * divisor)

        def _pad_end_of_bucket_if_needed(bucket_end_index: int) -> int:
            """
            Pads end index of bucket if using distributed optimizer (to ensure uniform sharding).
            """
            if self.ddp_config.use_distributed_optimizer:
                # Workaround for TE bug causing cuBLAS to pick an incompatible algorithm.
                # This also helps cuBLAS pick more efficient algorithms for GEMMs.
                # We now ensure that all buckets start at a memory address that is 256-byte
                # aligned (128 values since params and grads use >= 16-bit precision).
                return _pad(bucket_end_index, math.lcm(self.data_parallel_world_size, 128))
            return bucket_end_index

        def _pad_start_of_param_if_needed(param_start_index: int) -> int:
            """
            Pads start index of param if using distributed optimizer (to ensure "good" alignment).
            """
            if self.ddp_config.use_distributed_optimizer:
                # Ensure that params start at 128-byte aligned addresses (64 values
                # since params are >= 16-bit precision).
                return _pad(param_start_index, 64)
            return param_start_index

        # First, figure out how many elements should be in the underlying buffer storage.
        # Note that if we need to split the buffer into smaller buckets, each of these
        # might need to be padded as well (if using the distributed optimizer).
        data_start_index = 0
        bucket_data_start_index = data_start_index
        bucket_params = set()
        self.bucket_indices = []
        per_bucket_numel_unpadded = []
        bucket_id = 0

        def _create_new_bucket(data_end_index: int) -> int:
            """
            Create the bucket_id'th bucket with collected bucket_params, starting at
            bucket_data_start_index.
            """
            nonlocal bucket_data_start_index, bucket_params, bucket_id
            per_bucket_numel_unpadded.append(data_end_index - bucket_data_start_index)
            data_end_index = _pad_end_of_bucket_if_needed(data_end_index)
            # Update bucket metadata.
            self.bucket_indices.append((bucket_data_start_index, data_end_index))
            bucket_data_start_index = data_end_index
            # Re-set bucket_params and increment bucket_id for next bucket.
            bucket_params = set()
            bucket_id += 1
            # Return the potentially padded data_end_index.
            return data_end_index

        for param in params[::-1]:
            # Iterate through parameters in reverse order to roughly follow backprop order,
            # and skip parameters that don't require gradients.
            if not param.requires_grad:
                continue
            this_numel = param.data.nelement()
            data_start_index = _pad_start_of_param_if_needed(data_start_index)
            data_end_index = data_start_index + this_numel

            def _does_param_require_new_bucket(param):
                """
                Split shared embedding parameters into separate bucket if using distributed
                optimizer that makes use of reduce-scatters instead of all-reduces.
                This ensures that the first and last pipeline stage partition optimizer state
                for the shared embedding parameters the same way across DP replicas, allowing
                the DP reduce-scatter to be before the embedding all-reduce.
                """
                return (
                    getattr(param, "shared_embedding", False)
                    and self.ddp_config.use_distributed_optimizer
                )

            # Create bucket with already collected parameters if current param needs its own bucket.
            if _does_param_require_new_bucket(param) and len(bucket_params) > 0:
                # We are creating a bucket for the already accumulated parameters, whose params
                # end at the current data_start_index.
                if self.ddp_config.use_distributed_optimizer:
                    # data_start_index should already be padded.
                    assert data_start_index % self.data_parallel_world_size == 0
                _create_new_bucket(data_start_index)

            self.param_index_map[param] = (
                data_start_index,
                data_end_index,
                bucket_id,
            )
            bucket_params.add(param)

            # If we have enough elements already or the current param is part of the shared embedding
            # layer and needs a separate bucket, form a new bucket.
            if (
                bucket_size is not None
                and (data_end_index - bucket_data_start_index) >= bucket_size
            ) or _does_param_require_new_bucket(param):
                data_end_index = _create_new_bucket(data_end_index)
            data_start_index = data_end_index

        # Add remaining params to a new bucket.
        if len(bucket_params) > 0:
            data_end_index = _create_new_bucket(data_end_index)

        # Next, create underlying storage for buffer (with numel elements that includes
        # padding as necessary).
        self.numel = data_end_index
        self.numel_unpadded = sum(per_bucket_numel_unpadded)
        assert self.numel_unpadded <= self.numel
        if self.ddp_config.use_distributed_optimizer:
            assert self.numel % self.data_parallel_world_size == 0
        else:
            assert self.numel == self.numel_unpadded

        self.param_data = None
        # Only re-map param tensors if using distributed optimizer.
        if self.ddp_config.use_distributed_optimizer:
            self.param_data = torch.zeros(
                self.numel,
                dtype=self.param_dtype,
                device=torch.cuda.current_device(),
                requires_grad=False,
            )
        self.grad_data = torch.zeros(
            self.numel,
            dtype=self.grad_dtype,
            device=torch.cuda.current_device(),
            requires_grad=False,
        )

        # Finally, map param.data and param.main_grad fields to buffers.
        bucket_params = set()
        bucket_data_start_index = 0
        cur_bucket_id = 0
        for param in params[::-1]:
            if not param.requires_grad:
                continue
            data_start_index, data_end_index, bucket_id = self.param_index_map[param]

            # Assign param.data to appropriate segment of self.param_data.
            if self.param_data is not None:
                old_param_data = param.data
                param.data = self._get(
                    param.data.shape, data_start_index, buffer_type=BufferType.PARAM
                )
                assert old_param_data._base is None
                # Copy tensor values (from initialization or checkpoint).
                param.data.detach().copy_(old_param_data)
                del old_param_data

            param.main_grad = self._get(
                param.data.shape, data_start_index, buffer_type=BufferType.GRAD
            )
            if bucket_id != cur_bucket_id:
                bucket_data_end_index = _pad_end_of_bucket_if_needed(data_start_index)
                self._set_bucket(
                    bucket_params=bucket_params,
                    start_index=bucket_data_start_index,
                    end_index=bucket_data_end_index,
                    numel_unpadded=per_bucket_numel_unpadded[cur_bucket_id],
                    bucket_id=cur_bucket_id,
                )
                bucket_data_start_index = bucket_data_end_index
                bucket_params = set()
                assert cur_bucket_id + 1 == len(self.buckets)
                assert bucket_id == cur_bucket_id + 1
                cur_bucket_id = bucket_id
            bucket_params.add(param)

        # Add remaining params to a new bucket.
        if len(bucket_params) > 0:
            bucket_data_end_index = _pad_end_of_bucket_if_needed(data_end_index)
            self._set_bucket(
                bucket_params=bucket_params,
                start_index=bucket_data_start_index,
                end_index=bucket_data_end_index,
                numel_unpadded=per_bucket_numel_unpadded[cur_bucket_id],
                bucket_id=cur_bucket_id,
            )

        # Log buckets for all PP stages.
        log_strs = []
        log_strs.append(
            f'Number of buckets for gradient all-reduce / reduce-scatter: {len(self.buckets)}'
        )
        for index, bucket in enumerate(self.buckets):
            numel = 0
            for param in bucket.params:
                numel += param.data.nelement()
            log_strs.append(f'Params for bucket {index + 1} ({numel} elements):')
            for param in bucket.params:
                log_strs.append(f'\t{param_to_name[param]}')
        log_on_each_pipeline_stage(logger, logging.INFO, '\n'.join(log_strs))

    def scale_gradients(self, scaling_factor: float) -> None:
        """Scale the gradient data by `scaling_factor`."""
        self.grad_data *= scaling_factor

    def _get(self, shape: torch.Size, start_index: int, buffer_type: BufferType) -> torch.Tensor:
        """
        Return a tensor with the input `shape` as a view into the 1-D data starting at
        `start_index`.
        """
        end_index = start_index + shape.numel()
        assert end_index <= self.numel, 'Requested tensor is out of buffer range'
        if buffer_type == BufferType.PARAM:
            assert self.param_data is not None
            buffer_tensor = self.param_data[start_index:end_index]
        elif buffer_type == BufferType.GRAD:
            buffer_tensor = self.grad_data[start_index:end_index]
        else:
            raise Exception("Illegal buffer type provided to GradBuffer._get() function")
        buffer_tensor = buffer_tensor.view(shape)
        return buffer_tensor

    def _set_bucket(
        self,
        bucket_params: List[torch.nn.Parameter],
        start_index: int,
        end_index: int,
        numel_unpadded: int,
        bucket_id: int,
    ):
        """
        Helper function to create new bucket, add it to list of buckets, and
        also update param->bucket mapping.
        """

        # Assert that indices are correctly padded (if needed), and that bucket
        # position is same as originally computed.
        if self.ddp_config.use_distributed_optimizer:
            assert start_index % self.data_parallel_world_size == 0
            assert end_index % self.data_parallel_world_size == 0
        assert (start_index, end_index) == self.bucket_indices[bucket_id]

        # Get appropriate view into global ParamAndGradBuffer.
        bucketed_param_data = None
        if self.param_data is not None:
            bucketed_param_data = self._get(
                torch.Size([end_index - start_index]), start_index, buffer_type=BufferType.PARAM
            )
        bucketed_grad_data = self._get(
            torch.Size([end_index - start_index]), start_index, buffer_type=BufferType.GRAD
        )
        bucket = Bucket(
            ddp_config=self.ddp_config,
            params=bucket_params,
            param_data=bucketed_param_data,
            grad_data=bucketed_grad_data,
            offset=start_index,
            numel_unpadded=numel_unpadded,
            data_parallel_group=self.data_parallel_group,
            data_parallel_world_size=self.data_parallel_world_size,
            gradient_scaling_factor=self.gradient_scaling_factor,
        )
        self.buckets.append(bucket)
        for bucket_param in bucket_params:
            assert bucket_param not in self.param_to_bucket
            self.param_to_bucket[bucket_param] = bucket

    def reset(self):
        """
        Zero out the underlying grad_buffer and reset all buckets in preparation for the next
        iteration of training.
        """
        self.grad_data.zero_()
        for bucket in self.buckets:
            bucket.reset()
        self.is_last_microbatch = True

    def start_grad_sync(self):
        """
        Initiates grad sync (all-reduce or reduce-scatter) communication operations
        for all buckets in the grad buffer.

        When overlap_grad_reduce is set to True, dispatches asynchronous communication
        calls. When overlap_grad_reduce is set to False, calls synchronous
        communication ops.
        """
        for bucket in self.buckets:
            bucket.start_grad_sync()

    def finish_grad_sync(self):
        """
        Finishes grad sync (all-reduce or reduce-scatter) communication operations
        for all buckets in the grad buffer.

        When overlap_grad_reduce is set to True, waits for asynchronous communication
        calls to complete. When overlap_grad_reduce is set to False, calls synchronous
        communication ops.
        """
        for bucket in self.buckets:
            bucket.finish_grad_sync()

    def register_grad_ready(self, param: torch.nn.Parameter):
        """
        Registers grads for the passed-in param to be "ready" for grad sync.

        When the number of microbatches is greater than 1, we only want to register
        grads as ready when processing the last microbatch and overlap_grad_reduce is True.
        """
        assert (
            self.ddp_config.overlap_grad_reduce
        ), 'register_grad_ready() should only be called when overlap_grad_reduce is True'
        if self.is_last_microbatch:
            bucket = self.param_to_bucket[param]
            bucket.register_grad_ready(param)
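The `shard_buffer` helper above simply slices a flat 1-D tensor into `data_parallel_world_size` equal views, which is why the buffer's padded `numel` must be divisible by the world size. A small single-process sketch of that slicing behavior, with arbitrary illustration sizes:

# Standalone check of shard_buffer's slicing behavior (sizes are illustrative only).
import torch
from megatron.core.distributed.param_and_grad_buffer import shard_buffer

flat = torch.arange(12, dtype=torch.float32)   # stand-in for a padded grad buffer
shards = shard_buffer(flat, data_parallel_world_size=4)
assert len(shards) == 4 and all(s.numel() == 3 for s in shards)
# Each shard is a view, so writing through a shard mutates the flat buffer in place.
shards[0].zero_()
assert flat[:3].sum() == 0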
PAI-Megatron-LM-240718/megatron/core/enums.py (new file, 0 → 100644)
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

import enum


class ModelType(enum.Enum):
    encoder_or_decoder = 1
    encoder_and_decoder = 2
    retro_encoder = 3
    retro_decoder = 4
PAI-Megatron-LM-240718/megatron/core/fusions/__init__.py (new empty file, 0 → 100644)
PAI-Megatron-LM-240718/megatron/core/fusions/fused_bias_dropout.py (new file, 0 → 100644)
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from typing import Optional, Tuple

import torch

from megatron.core.jit import jit_fuser


def _bias_dropout_add_func(x_with_bias, residual, prob, training):
    # type: (Tuple[Tensor, Optional[Tensor]], Tensor, float, bool) -> Tensor
    # NOTE: Previously, the argument `bias` used to be passed as
    # `bias.expand_as(residual)` when the `bias_dropout_func` is called from the
    # transformer layer but broadcasting should automatically take care of that.
    # Also, looking at broadcasting semantics, `expand_as` and broadcasting
    # seem to be identical performance-wise (both just change the view).

    x, bias = x_with_bias  # unpack

    # If we want to train mixed precision, then the output of this function
    # should be half precision. However, in AMP O1, the input (residual) is
    # in fp32, and it will up-cast the result to fp32, causing pipeline parallel
    # GPU communication to hang. Therefore, we need to cast residual to the same
    # dtype as x.
    residual = residual if residual.dtype == x.dtype else residual.to(x.dtype)

    # The Dropout operation, Residual Addition and the tensor returning can be
    # done generically outside the if statement, but that stops fusing of Bias
    # Addition-Dropout-Residual Addition operation. So doing it together inside
    # the conditional branch to improve performance
    if bias is not None:
        x = x + bias
        out = torch.nn.functional.dropout(x, p=prob, training=training)
        out = residual + out
        return out
    else:
        out = torch.nn.functional.dropout(x, p=prob, training=training)
        out = residual + out
        return out


def bias_dropout_add_unfused(training):
    def _bias_dropout_add(x_with_bias, residual, prob):
        return _bias_dropout_add_func(x_with_bias, residual, prob, training)

    return _bias_dropout_add


@jit_fuser
def bias_dropout_add_fused_train(
    x_with_bias: Tuple[torch.Tensor, Optional[torch.Tensor]],
    residual: torch.Tensor,
    prob: float,
) -> torch.Tensor:
    return _bias_dropout_add_func(x_with_bias, residual, prob, True)


@jit_fuser
def bias_dropout_add_fused_inference(
    x_with_bias: Tuple[torch.Tensor, Optional[torch.Tensor]],
    residual: torch.Tensor,
    prob: float,
) -> torch.Tensor:
    return _bias_dropout_add_func(x_with_bias, residual, prob, False)


def get_bias_dropout_add(training, fused):
    if fused:
        # jit scripting for a nn.module (with dropout) is not
        # triggering the fusion kernel. For now, we use two
        # different nn.functional routines to account for varying
        # dropout semantics during training and inference phases.
        if training:
            return bias_dropout_add_fused_train
        else:
            return bias_dropout_add_fused_inference
    else:
        return bias_dropout_add_unfused(training)
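`get_bias_dropout_add` is a factory: callers select the fused or unfused variant once, then call the returned function with an `(output, bias)` tuple plus the residual. A minimal sketch of that calling convention, with made-up shapes and the unfused path chosen to avoid any jit dependency here:

# Illustrative call of the bias-dropout-add path (shapes and values are arbitrary).
import torch
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add

hidden = 8
x = torch.randn(4, 2, hidden)          # e.g. attention or MLP output
bias = torch.randn(hidden)             # broadcast over the leading dims
residual = torch.randn(4, 2, hidden)

bda = get_bias_dropout_add(training=True, fused=False)
out = bda((x, bias), residual, prob=0.1)
assert out.shape == residual.shape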
PAI-Megatron-LM-240718/megatron/core/fusions/fused_bias_geglu.py (new file, 0 → 100644)
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

import torch

from megatron.core.jit import jit_fuser

###### BIAS GELU FUSION/ NO AUTOGRAD ################
# 1/sqrt(2*pi)-> 0.3989423
# 1/sqrt(2)   -> 0.70710678
# sqrt(2/pi)  -> 0.79788456
# this function is tanh approximation of gelu
# actual gelu is:
# x * 0.5 * (1.0 + torch.erf(x * 0.70710678))


@jit_fuser
def geglu(y):
    y_1, y_2 = torch.chunk(y, 2, -1)
    return (y_1 * 0.5 * (1.0 + torch.tanh(0.79788456 * y_1 * (1 + 0.044715 * y_1 * y_1)))) * y_2


@jit_fuser
def bias_geglu(bias, y):
    y = y + bias
    return geglu(y)


# gradient of tanh approximation of gelu
# gradient of actual gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
@jit_fuser
def geglu_back(g, y):
    y_1, y_2 = torch.chunk(y, 2, -1)
    tanh_out = torch.tanh(0.79788456 * y_1 * (1 + 0.044715 * y_1 * y_1))
    # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
    ff = 0.5 * y_1 * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * y_1 * y_1)) + 0.5 * (
        1 + tanh_out
    )
    return torch.cat(((g * y_2) * ff, g * (y_1 * 0.5 * (1.0 + tanh_out))), -1)


@jit_fuser
def bias_geglu_back(g, y, bias):
    y = y + bias
    return geglu_back(g, y)


class BiasGeGLUFunction(torch.autograd.Function):
    @staticmethod
    # bias is an optional argument
    def forward(ctx, input, bias):
        ctx.save_for_backward(input, bias)
        return bias_geglu(input, bias)

    @staticmethod
    def backward(ctx, grad_output):
        input, bias = ctx.saved_tensors
        tmp = bias_geglu_back(grad_output, input, bias)
        return tmp, tmp


class GeGLUFunction(torch.autograd.Function):
    @staticmethod
    # bias is an optional argument
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return geglu(input)

    @staticmethod
    def backward(ctx, grad_output):
        input = ctx.saved_tensors
        tmp = geglu_back(grad_output, input[0])
        return tmp


def bias_geglu_impl(input, bias):
    ori_shape = input.shape
    assert len(ori_shape) in [2, 3]
    input = input.view(-1, ori_shape[-1])
    if bias is not None:
        output = BiasGeGLUFunction.apply(input, bias)
    else:
        output = GeGLUFunction.apply(input)

    return output if len(ori_shape) == 2 else output.view(ori_shape[0], ori_shape[1], -1)
PAI-Megatron-LM-240718/megatron/core/fusions/fused_bias_gelu.py (new file, 0 → 100644)
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

import torch

from megatron.core.jit import jit_fuser

###### BIAS GELU FUSION/ NO AUTOGRAD ################
# 1/sqrt(2*pi)-> 0.3989423
# 1/sqrt(2)   -> 0.70710678
# sqrt(2/pi)  -> 0.79788456
# this function is tanh approximation of gelu
# actual gelu is:
# x * 0.5 * (1.0 + torch.erf(x * 0.70710678))


@jit_fuser
def bias_gelu(bias, y):
    x = bias + y
    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))


# gradient of tanh approximation of gelu
# gradient of actual gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
@jit_fuser
def bias_gelu_back(g, bias, y):
    x = bias + y
    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
        1 + tanh_out
    )
    return ff * g


class GeLUFunction(torch.autograd.Function):
    @staticmethod
    # bias is an optional argument
    def forward(ctx, input, bias):
        ctx.save_for_backward(input, bias)
        return bias_gelu(bias, input)

    @staticmethod
    def backward(ctx, grad_output):
        input, bias = ctx.saved_tensors
        tmp = bias_gelu_back(grad_output, bias, input)
        return tmp, tmp


bias_gelu_impl = GeLUFunction.apply
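The constants above implement the tanh approximation of GELU, so `bias_gelu(bias, y)` should agree closely with PyTorch's built-in tanh-approximate GELU applied to `bias + y`. A quick sanity sketch, with arbitrary sizes and tolerance:

# Numerical sanity check: the hand-written tanh-GELU matches torch's tanh-approximate GELU.
import torch
import torch.nn.functional as F
from megatron.core.fusions.fused_bias_gelu import bias_gelu

y = torch.randn(16, 32, dtype=torch.float64)
bias = torch.randn(32, dtype=torch.float64)
ref = F.gelu(y + bias, approximate='tanh')
assert torch.allclose(bias_gelu(bias, y), ref, atol=1e-6)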
PAI-Megatron-LM-240718/megatron/core/fusions/fused_bias_swiglu.py (new file, 0 → 100644)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

import torch
import torch.nn.functional as F

from megatron.core.jit import jit_fuser

###### BIAS SWIGLU FUSION/ NO AUTOGRAD ################


@jit_fuser
def swiglu(y):
    y_1, y_2 = torch.chunk(y, 2, -1)
    return F.silu(y_1) * y_2


@jit_fuser
def bias_swiglu(y, bias):
    y = y + bias
    return swiglu(y)


# gradient of swiglu (silu(y_1) * y_2) with respect to y_1 and y_2
@jit_fuser
def swiglu_back(g, y):
    y_1, y_2 = torch.chunk(y, 2, -1)
    return torch.cat(
        (g * torch.sigmoid(y_1) * (1 + y_1 * (1 - torch.sigmoid(y_1))) * y_2, g * F.silu(y_1)), -1
    )


@jit_fuser
def bias_swiglu_back(g, y, bias):
    y = y + bias
    return swiglu_back(g, y)


class BiasSwiGLUFunction(torch.autograd.Function):
    @staticmethod
    # bias is an optional argument
    def forward(ctx, input, bias, fp8_input_store):
        input_for_backward = input.to(torch.float8_e4m3fn) if fp8_input_store else input
        ctx.save_for_backward(input_for_backward, bias)
        ctx.ori_input_dtype = input.dtype
        ctx.fp8_input_store = fp8_input_store
        return bias_swiglu(input, bias)

    @staticmethod
    def backward(ctx, grad_output):
        input, bias = ctx.saved_tensors
        input = input.to(ctx.ori_input_dtype) if ctx.fp8_input_store else input
        tmp = bias_swiglu_back(grad_output, input, bias)
        return tmp, tmp, None


class SwiGLUFunction(torch.autograd.Function):
    @staticmethod
    # bias is an optional argument
    def forward(ctx, input, fp8_input_store):
        input_for_backward = input.to(torch.float8_e4m3fn) if fp8_input_store else input
        ctx.save_for_backward(input_for_backward)
        ctx.ori_input_dtype = input.dtype
        ctx.fp8_input_store = fp8_input_store
        return swiglu(input)

    @staticmethod
    def backward(ctx, grad_output):
        input = ctx.saved_tensors[0]
        input = input.to(ctx.ori_input_dtype) if ctx.fp8_input_store else input
        tmp = swiglu_back(grad_output, input)
        return tmp, None


def bias_swiglu_impl(input, bias, fp8_input_store=False):
    ori_shape = input.shape
    assert len(ori_shape) in [2, 3]
    input = input.view(-1, ori_shape[-1])
    if bias is not None:
        output = BiasSwiGLUFunction.apply(input, bias, fp8_input_store)
    else:
        output = SwiGLUFunction.apply(input, fp8_input_store)

    return output if len(ori_shape) == 2 else output.view(ori_shape[0], ori_shape[1], -1)


# bias_swiglu_impl = BiasSwiGLUFunction.apply
# swiglu_impl = SwiGLUFunction.apply
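`swiglu` splits the last dimension in half and gates one half through SiLU of the other, so the output has half the feature width of the input. A short shape/consistency sketch under that reading, with illustrative sizes only:

# Shape check for the SwiGLU path: the last dim is halved, and the autograd wrapper matches
# the plain functional composition. (Sizes are illustrative.)
import torch
import torch.nn.functional as F
from megatron.core.fusions.fused_bias_swiglu import bias_swiglu_impl

x = torch.randn(6, 10, requires_grad=True)      # last dim must be even
out = bias_swiglu_impl(x, bias=None)            # no bias; fp8_input_store defaults to False
assert out.shape == (6, 5)

y_1, y_2 = torch.chunk(x, 2, -1)
assert torch.allclose(out, F.silu(y_1) * y_2, atol=1e-6)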
PAI-Megatron-LM-240718/megatron/core/fusions/fused_cross_entropy.py (new file, 0 → 100644)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

from typing import Tuple

import torch

from megatron.core.jit import jit_fuser
from megatron.core.parallel_state import (
    get_tensor_model_parallel_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
from megatron.core.tensor_parallel.cross_entropy import VocabParallelCrossEntropy
from megatron.core.tensor_parallel.utils import VocabUtility


@jit_fuser
def calculate_logits_max(vocab_parallel_logits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    vocab_parallel_logits, logits_max = VocabParallelCrossEntropy.calculate_logits_max(
        vocab_parallel_logits
    )
    return vocab_parallel_logits, logits_max


@jit_fuser
def calculate_predicted_logits(
    vocab_parallel_logits: torch.Tensor,
    target: torch.Tensor,
    logits_max: torch.Tensor,
    vocab_start_index: int,
    vocab_end_index: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    (
        target_mask,
        masked_target_1d,
        predicted_logits,
        sum_exp_logits,
        exp_logits,
    ) = VocabParallelCrossEntropy.calculate_predicted_logits(
        vocab_parallel_logits, target, logits_max, vocab_start_index, vocab_end_index
    )

    predicted_logits_sum_exp_logits = torch.cat((predicted_logits, sum_exp_logits))

    return target_mask, masked_target_1d, predicted_logits_sum_exp_logits, exp_logits


@jit_fuser
def calculate_cross_entropy_loss(
    exp_logits: torch.Tensor, predicted_logits_sum_exp_logits: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    split_val = predicted_logits_sum_exp_logits.size()[0] // 2
    predicted_logits, sum_exp_logits = torch.split(predicted_logits_sum_exp_logits, split_val)

    exp_logits, loss = VocabParallelCrossEntropy.calculate_cross_entropy_loss(
        exp_logits, predicted_logits, sum_exp_logits
    )

    return exp_logits, loss


@jit_fuser
def calculate_gradients(
    softmax: torch.Tensor,
    grad_output: torch.Tensor,
    target_mask: torch.Tensor,
    masked_target_1d: torch.Tensor,
) -> torch.Tensor:
    (
        grad_2d,
        arange_1d,
        softmax_update,
        grad_input,
    ) = VocabParallelCrossEntropy.prepare_gradient_calculation_operands(softmax, target_mask)

    grad_input = VocabParallelCrossEntropy.calculate_gradients(
        grad_2d, arange_1d, masked_target_1d, softmax_update, grad_input, grad_output
    )

    grad_input = grad_input.to(torch.bfloat16)

    return grad_input


class _VocabParallelCrossEntropy(torch.autograd.Function):
    @staticmethod
    def forward(ctx, vocab_parallel_logits, target):

        vocab_parallel_logits, logits_max = calculate_logits_max(vocab_parallel_logits)
        torch.distributed.all_reduce(
            logits_max, op=torch.distributed.ReduceOp.MAX, group=get_tensor_model_parallel_group()
        )

        # Get the partition's vocab indices
        get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size
        partition_vocab_size = vocab_parallel_logits.size()[-1]
        rank = get_tensor_model_parallel_rank()
        world_size = get_tensor_model_parallel_world_size()
        vocab_start_index, vocab_end_index = get_vocab_range(partition_vocab_size, rank, world_size)

        (
            target_mask,
            masked_target_1d,
            predicted_logits_sum_exp_logits,
            exp_logits,
        ) = calculate_predicted_logits(
            vocab_parallel_logits, target, logits_max, vocab_start_index, vocab_end_index
        )

        # All-reduce is needed to get the chunks from other GPUs.
        # In the fused case, tensors are batched to invoke a single
        # AllReduce call
        torch.distributed.all_reduce(
            predicted_logits_sum_exp_logits,
            op=torch.distributed.ReduceOp.SUM,
            group=get_tensor_model_parallel_group(),
        )

        exp_logits, loss = calculate_cross_entropy_loss(exp_logits, predicted_logits_sum_exp_logits)

        # Store softmax, target-mask and masked-target for backward pass.
        ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)

        return loss

    @staticmethod
    def backward(ctx, grad_output):

        # Retrieve tensors from the forward path.
        softmax, target_mask, masked_target_1d = ctx.saved_tensors

        grad_input = calculate_gradients(softmax, grad_output, target_mask, masked_target_1d)

        return grad_input, None


def fused_vocab_parallel_cross_entropy(vocab_parallel_logits, target):
    """
    Performs cross entropy loss when logits are split across tensor parallel ranks

    Args:
        vocab_parallel_logits: logits split across tensor parallel ranks,
            dimension is [sequence_length, batch_size, hidden_size]
        target: correct vocab ids of dimension [sequence_length, micro_batch_size]
    """
    return _VocabParallelCrossEntropy.apply(vocab_parallel_logits, target)
PAI-Megatron-LM-240718/megatron/core/fusions/fused_layer_norm.py (new file, 0 → 100644)
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

import importlib
import inspect
import numbers

import torch
from torch import Tensor
from torch.nn import init
from torch.nn.parameter import Parameter

from megatron.core.transformer import TransformerConfig
from megatron.core.utils import make_viewless_tensor

try:
    from apex.contrib.layer_norm.layer_norm import FastLayerNormFN

    HAVE_PERSIST_LAYER_NORM = True
except:
    HAVE_PERSIST_LAYER_NORM = False

try:
    from apex.normalization.fused_layer_norm import FusedLayerNormAffineFunction

    HAVE_FUSED_LAYER_NORM = True
except:
    HAVE_FUSED_LAYER_NORM = False


class FusedLayerNorm(torch.nn.Module):
    """Layer Norm, fused into a single CUDA kernel.

    Args:
        hidden_size (int): Transformer hidden dimension.
        eps (float): Epsilon added to denominator, for numerical stability.
        persist_layer_norm (bool): Use persistent fused layer norm kernel.
            This kernel supports only a set of hidden sizes. Please
            check persist_ln_hidden_sizes if your hidden size is supported.
        zero_centered_gamma (bool): Adjust LayerNorm weights such that they are
            centered around zero. This improves numerical stability.
        config (TransformerConfig): Transformer config. Include to match custom
            layer norm interfaces.
        normalization (str): Normalization type, used for Transformer Engine.
            Must equal 'LayerNorm' here.
    """

    def __init__(
        self,
        config: TransformerConfig,
        hidden_size: int,
        eps: float = 1e-5,
        persist_layer_norm: bool = True,
        zero_centered_gamma: bool = False,
        normalization: str = "LayerNorm",  # included to match TE interface
    ):
        super().__init__()

        self.config = config

        self.zero_centered_gamma = self.config.layernorm_zero_centered_gamma
        assert (
            self.config.normalization == "LayerNorm"
        ), f'({self.config.normalization}) is not supported in FusedLayerNorm'

        # List of hidden sizes supported in the persistent layer norm kernel
        # If the hidden size is not supported, fall back to the non-persistent
        # kernel.
        persist_ln_hidden_sizes = [
            1024,
            1536,
            2048,
            2304,
            3072,
            3840,
            4096,
            5120,
            6144,
            8192,
            10240,
            12288,
            12800,
            15360,
            16384,
            18432,
            20480,
            24576,
            25600,
            30720,
            32768,
            40960,
            49152,
            65536,
        ]
        persist_layer_norm = self.config.persist_layer_norm
        if hidden_size not in persist_ln_hidden_sizes or not HAVE_PERSIST_LAYER_NORM:
            persist_layer_norm = False

        if not persist_layer_norm and not HAVE_FUSED_LAYER_NORM:
            # TODO: Add pytorch only layer norm
            raise ValueError(f'Apex must be installed to use FusedLayerNorm.')

        if isinstance(hidden_size, numbers.Integral):
            hidden_size = (hidden_size,)
        self.hidden_size = torch.Size(hidden_size)
        self.eps = eps
        # Parameters need to be initialized with torch.empty rather than torch.Tensor for correct device placement with nemo2.
        self.weight = Parameter(torch.empty(*hidden_size))
        self.bias = Parameter(torch.empty(*hidden_size))
        self.reset_parameters()
        self.persist_layer_norm = persist_layer_norm
        self.sequence_parallel = self.config.sequence_parallel

        # set sequence parallelism flag on weight and bias parameters
        setattr(self.weight, 'sequence_parallel', self.sequence_parallel)
        setattr(self.bias, 'sequence_parallel', self.sequence_parallel)

    def reset_parameters(self):

        if self.zero_centered_gamma:
            init.zeros_(self.weight)
            init.zeros_(self.bias)
        else:
            init.ones_(self.weight)
            init.zeros_(self.bias)

    def forward(self, input: Tensor) -> Tensor:

        weight = self.weight + 1 if self.zero_centered_gamma else self.weight

        if self.persist_layer_norm:
            if 'memory_efficient' in inspect.getfullargspec(FastLayerNormFN.forward).args:
                output = FastLayerNormFN.apply(
                    input, weight, self.bias, self.eps, self.config.memory_efficient_layer_norm
                )
            else:
                output = FastLayerNormFN.apply(input, weight, self.bias, self.eps)

            # Apex's fast layer norm function outputs a 'view' tensor (i.e., has
            # a populated '_base' field). This will result in schedule.py's
            # deallocate_output_tensor() throwing an error, so a viewless tensor is
            # created to prevent this.
            output = make_viewless_tensor(
                inp=output, requires_grad=input.requires_grad, keep_graph=True
            )

        else:
            if (
                'memory_efficient'
                in inspect.getfullargspec(FusedLayerNormAffineFunction.forward).args
            ):
                return FusedLayerNormAffineFunction.apply(
                    input,
                    weight,
                    self.bias,
                    self.hidden_size,
                    self.eps,
                    self.config.memory_efficient_layer_norm,
                )
            else:
                return FusedLayerNormAffineFunction.apply(
                    input, weight, self.bias, self.hidden_size, self.eps
                )

        return output
PAI-Megatron-LM-240718/megatron/core/fusions/fused_softmax.py (new file, 0 → 100644)
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

from typing import Optional

import torch
import torch.nn as nn

from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.utils import get_default_causal_mask


class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
    """
    Fused operation which performs following three operations in sequence
    1. Scale the tensor.
    2. Apply upper triangular mask (typically used in gpt models).
    3. Perform softmax.
    """

    @staticmethod
    def forward(ctx, inputs, scale):
        import scaled_upper_triang_masked_softmax_cuda

        scale_t = torch.tensor([scale])
        softmax_results = scaled_upper_triang_masked_softmax_cuda.forward(inputs, scale_t[0])

        ctx.save_for_backward(softmax_results, scale_t)
        return softmax_results

    @staticmethod
    def backward(ctx, output_grads):
        import scaled_upper_triang_masked_softmax_cuda

        softmax_results, scale_t = ctx.saved_tensors
        input_grads = scaled_upper_triang_masked_softmax_cuda.backward(
            output_grads, softmax_results, scale_t[0]
        )

        return input_grads, None


class ScaledMaskedSoftmax(torch.autograd.Function):
    """
    Fused operation which performs following three operations in sequence
    1. Scale the tensor.
    2. Apply the mask.
    3. Perform softmax.
    """

    @staticmethod
    def forward(ctx, inputs, mask, scale):
        import scaled_masked_softmax_cuda

        scale_t = torch.tensor([scale])

        softmax_results = scaled_masked_softmax_cuda.forward(inputs, mask, scale_t[0])
        ctx.save_for_backward(softmax_results, scale_t)
        return softmax_results

    @staticmethod
    def backward(ctx, output_grads):
        import scaled_masked_softmax_cuda

        softmax_results, scale_t = ctx.saved_tensors

        input_grads = scaled_masked_softmax_cuda.backward(output_grads, softmax_results, scale_t[0])
        return input_grads, None, None


class ScaledSoftmax(torch.autograd.Function):
    """
    Fused operation which performs following two operations in sequence
    1. Scale the tensor.
    2. Perform softmax.
    """

    @staticmethod
    def forward(ctx, inputs, scale):
        import scaled_softmax_cuda

        scale_t = torch.tensor([scale])

        softmax_results = scaled_softmax_cuda.forward(inputs, scale_t[0])
        ctx.save_for_backward(softmax_results, scale_t)
        return softmax_results

    @staticmethod
    def backward(ctx, output_grads):
        import scaled_softmax_cuda

        softmax_results, scale_t = ctx.saved_tensors

        input_grads = scaled_softmax_cuda.backward(output_grads, softmax_results, scale_t[0])
        return input_grads, None, None


class FusedScaleMaskSoftmax(nn.Module):
    """
    fused operation: scaling + mask + softmax

    Args:
        input_in_fp16: flag to indicate if input is in fp16 data format.
        input_in_bf16: flag to indicate if input is in bf16 data format.
        attn_mask_type: attention mask type (pad or causal)
        scaled_masked_softmax_fusion: flag to indicate the user wants to use softmax fusion
        mask_func: mask function to be applied.
        softmax_in_fp32: if true, softmax is performed at fp32 precision.
        scale: scaling factor used in input tensor scaling.
    """

    def __init__(
        self,
        input_in_fp16,
        input_in_bf16,
        attn_mask_type,
        scaled_masked_softmax_fusion,
        mask_func,
        softmax_in_fp32,
        scale,
    ):
        super(FusedScaleMaskSoftmax, self).__init__()
        self.input_in_fp16 = input_in_fp16
        self.input_in_bf16 = input_in_bf16
        assert not (
            self.input_in_fp16 and self.input_in_bf16
        ), "both fp16 and bf16 flags cannot be active at the same time."
        self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
        self.attn_mask_type = attn_mask_type
        self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
        self.mask_func = mask_func
        self.softmax_in_fp32 = softmax_in_fp32
        self.scale = scale

        assert self.scale is None or softmax_in_fp32, "softmax should be in fp32 when scaled"

    def forward(self, input: torch.Tensor, mask: Optional[torch.Tensor]):
        """Forward pass of softmax with masked input.

        In case attn_mask_type is causal the mask is generated and None can be passed.
        A user-defined mask is only needed when attn_mask_type is not causal.
        """
        # [b, np, sq, sk]
        assert input.dim() == 4

        if self.is_kernel_available(mask, *input.size()):
            return self.forward_fused_softmax(input, mask)
        else:
            return self.forward_torch_softmax(input, mask)

    def is_kernel_available(self, mask, b, np, sq, sk):
        attn_batches = b * np

        if (
            self.scaled_masked_softmax_fusion  # user wants to fuse
            and self.input_in_float16  # input must be fp16
            and 16 < sk <= 4096  # sk must be in the range (16, 4096]
            and sq % 4 == 0  # sq must be divisible by 4
            and sk % 4 == 0  # sk must be divisible by 4
            and attn_batches % 4 == 0  # np * b must be divisible by 4
        ):
            if 0 <= sk <= 4096:
                batch_per_block = self.get_batch_per_block(sq, sk, b, np)

                if self.attn_mask_type == AttnMaskType.causal:
                    if attn_batches % batch_per_block == 0:
                        return True
                else:
                    if sq % batch_per_block == 0:
                        return True
        return False

    def forward_fused_softmax(self, input, mask):
        b, np, sq, sk = input.size()
        scale = self.scale if self.scale is not None else 1.0

        if self.attn_mask_type == AttnMaskType.causal:
            assert sq == sk, "causal mask is only for self attention"

            # input is 3D tensor (attn_batches, sq, sk)
            input = input.view(-1, sq, sk)
            probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale)
            return probs.view(b, np, sq, sk)
        else:
            # input is 4D tensor (b, np, sq, sk)
            if mask is not None:
                return ScaledMaskedSoftmax.apply(input, mask, scale)
            else:
                return ScaledSoftmax.apply(input, scale)

    def forward_torch_softmax(self, input, mask):
        if self.input_in_float16 and self.softmax_in_fp32:
            input = input.float()

        if self.scale is not None:
            input = input * self.scale

        # Generate causal mask if not given
        sq, sk = input.size(2), input.size(3)
        if self.attn_mask_type == AttnMaskType.causal and mask is None and sq > 1:
            # If sq == 1 then either KV cache is used or one-element context is passed
            # so keeping mask=None in this case; subsequent code should handle it
            assert sq == sk, "causal mask is only for self attention"
            mask = get_default_causal_mask(sq)

        mask_output = self.mask_func(input, mask) if mask is not None else input
        probs = torch.nn.Softmax(dim=-1)(mask_output)

        if self.input_in_float16 and self.softmax_in_fp32:
            if self.input_in_fp16:
                probs = probs.half()
            else:
                probs = probs.bfloat16()

        return probs

    @staticmethod
    def get_batch_per_block(sq, sk, b, np):
        import scaled_masked_softmax_cuda

        return scaled_masked_softmax_cuda.get_batch_per_block(sq, sk, b, np)
PAI-Megatron-LM-240718/megatron/core/inference/__init__.py (new file, 0 → 100644)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
PAI-Megatron-LM-240718/megatron/core/inference/ammo_support/__init__.py (new empty file, 0 → 100644)
PAI-Megatron-LM-240718/megatron/core/inference/ammo_support/gpt/__init__.py (new file, 0 → 100644)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
PAI-Megatron-LM-240718/megatron/core/inference/ammo_support/gpt/model_specs.py (new file, 0 → 100644)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.custom_layers.transformer_engine import (
    TEDotProductAttention,
    TENorm,
)
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules


# Use this spec for ModelOpt PTQ and TensorRT-LLM export
def get_gpt_layer_modelopt_spec(
    remap_te_layernorm: bool = False, qk_layernorm: bool = False
) -> ModuleSpec:
    """Mix the native spec with TENorm.

    This is essentially the native local spec, except that the layernorm implementation
    uses TENorm from Transformer-Engine. The issue is that FusedLayerNorm from apex
    has stopped supporting the RMSNorm needed by llama.
    """
    sharded_state_dict_keys_map = {}
    if remap_te_layernorm:
        sharded_state_dict_keys_map = {
            'input_layernorm.': 'self_attention.linear_qkv.layer_norm_',
            'pre_mlp_layernorm.': 'mlp.linear_fc1.layer_norm_',
        }
    return ModuleSpec(
        module=TransformerLayer,
        submodules=TransformerLayerSubmodules(
            input_layernorm=TENorm,
            self_attention=ModuleSpec(
                module=SelfAttention,
                params={"attn_mask_type": AttnMaskType.causal},
                submodules=SelfAttentionSubmodules(
                    linear_qkv=ColumnParallelLinear,
                    core_attention=TEDotProductAttention,
                    linear_proj=RowParallelLinear,
                    q_layernorm=TENorm if qk_layernorm else IdentityOp,
                    k_layernorm=TENorm if qk_layernorm else IdentityOp,
                ),
            ),
            self_attn_bda=get_bias_dropout_add,
            pre_mlp_layernorm=TENorm,
            mlp=ModuleSpec(
                module=MLP,
                submodules=MLPSubmodules(
                    linear_fc1=ColumnParallelLinear,
                    linear_fc2=RowParallelLinear,
                ),
            ),
            mlp_bda=get_bias_dropout_add,
            # Map TE-layernorm-fusion keys back
            sharded_state_dict_keys_map=sharded_state_dict_keys_map,
        ),
    )
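A hedged usage sketch: the returned ModuleSpec is typically passed as the transformer layer spec when building a megatron/core GPT model. The import paths, keyword names, and config values below are illustrative assumptions (and model-parallel state must already be initialized); they are not taken from this file.

# Hedged sketch, assuming megatron.core is importable and parallel state is initialized.
from megatron.core.models.gpt import GPTModel                      # assumed import path
from megatron.core.transformer.transformer_config import TransformerConfig  # assumed import path

layer_spec = get_gpt_layer_modelopt_spec(remap_te_layernorm=True, qk_layernorm=False)

config = TransformerConfig(  # illustrative values only
    num_layers=2, hidden_size=128, num_attention_heads=8, use_cpu_initialization=True
)
model = GPTModel(
    config=config,
    transformer_layer_spec=layer_spec,  # the spec built above
    vocab_size=32000,
    max_sequence_length=2048,
)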
PAI-Megatron-LM-240718/megatron/core/inference/ammo_support/gpt/state_dict_hooks.py
0 → 100644
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

from logging import getLogger

import torch

logger = getLogger(__name__)


def mcore_gpt_load_legacy_state_dict_pre_hook(
    state_dict,
    prefix,
    local_metadata,
    strict,
    missing_keys,
    unexpected_keys,
    error_msgs,
):
    """Register a pre-hook to fix the state_dict key difference.

    This prehook is used when trying to load the legacy Megatron-LM GPTModel into its
    megatron/core variant that uses native ParallelLinear and Transformer-Engine Norm.
    Only this particular spec supports post-training quantization and TensorRT-LLM
    config export through the `nvidia-modelopt` package.

    Args:
        state_dict: state dictionary
        prefix: module name prefix
        local_metadata: local metadata
        strict: whether it is in strict mode
        missing_keys: missing state dict keys
        unexpected_keys: unexpected state dict keys
        error_msgs: error messages
    """
    if "modelopt_state" in state_dict:
        state_dict.pop("modelopt_state")

    if "language_model" in state_dict:
        language_model_state_dict = state_dict.pop("language_model")
        if "embedding" in language_model_state_dict:
            if "word_embeddings" in language_model_state_dict["embedding"]:
                for key, param in language_model_state_dict["embedding"]["word_embeddings"].items():
                    state_dict.update({"embedding.word_embeddings." + key: param})
            if "position_embeddings" in language_model_state_dict["embedding"]:
                for key, param in language_model_state_dict["embedding"][
                    "position_embeddings"
                ].items():
                    state_dict.update({"embedding.position_embeddings." + key: param})
        if "transformer" in language_model_state_dict:
            for key, param in language_model_state_dict["transformer"].items():
                state_dict.update({"decoder." + key: param})
        else:
            for key, param in language_model_state_dict["encoder"].items():
                state_dict.update({"decoder." + key: param})
        if "output_layer" in language_model_state_dict:
            for key, param in language_model_state_dict["output_layer"].items():
                state_dict.update({"output_layer." + key: param})

    if torch.distributed.get_rank() == 0:
        logger.info("ModelOptGPTModel {}".format(state_dict.keys()))

    module_name_rewrite_list = [
        ("input_norm", "input_layernorm"),
        (".attention.query_key_value", ".self_attention.linear_qkv"),
        (".attention.dense", ".self_attention.linear_proj"),
        ("self_attention.query_key_value", "self_attention.linear_qkv"),
        ("self_attention.dense", "self_attention.linear_proj"),
        ("post_attention_layernorm", "pre_mlp_layernorm"),
        ("post_attention_norm", "pre_mlp_layernorm"),
        ("dense_h_to_4h", "linear_fc1"),
        ("dense_4h_to_h", "linear_fc2"),
        ("final_norm", "final_layernorm"),
    ]

    key_rewrite_list = []

    for key, _ in state_dict.items():
        for old_name, new_name in module_name_rewrite_list:
            if old_name in key:
                key_rewrite_list += [(key, key.replace(old_name, new_name))]

    for old_key, new_key in key_rewrite_list:
        if torch.distributed.get_rank() == 0:
            logger.info("replace {} with {}".format(old_key, new_key))
        state_dict[new_key] = state_dict[old_key]
        state_dict.pop(old_key)


def mcore_gpt_load_te_state_dict_pre_hook(
    state_dict,
    prefix,
    local_metadata,
    strict,
    missing_keys,
    unexpected_keys,
    error_msgs,
):
    """Register a pre-hook to fix the state_dict key difference.

    This prehook is used when trying to load the megatron/core GPTModel that uses a
    fused Transformer-Engine ParallelLinear into the variant that uses native ParallelLinear
    and Transformer-Engine Norm (effectively to restore the fusion).
    Only this particular spec supports post-training quantization and TensorRT-LLM
    config export through the `nvidia-modelopt` package.

    Args:
        state_dict: state dictionary
        prefix: module name prefix
        local_metadata: local metadata
        strict: whether it is in strict mode
        missing_keys: missing state dict keys
        unexpected_keys: unexpected state dict keys
        error_msgs: error messages
    """
    if "modelopt_state" in state_dict:
        state_dict.pop("modelopt_state")

    key_with_te_extra_state_to_pop = []

    for key, _ in state_dict.items():
        if "_extra_state" in key:
            key_with_te_extra_state_to_pop += [key]

    for key in key_with_te_extra_state_to_pop:
        state_dict.pop(key)

    module_name_rewrite_list = [
        ("self_attention.linear_qkv.layer_norm_weight", "input_layernorm.weight"),
        ("self_attention.linear_qkv.layer_norm_bias", "input_layernorm.bias"),
        ("mlp.linear_fc1.layer_norm_weight", "pre_mlp_layernorm.weight"),
        ("mlp.linear_fc1.layer_norm_bias", "pre_mlp_layernorm.bias"),
    ]

    key_rewrite_list = []

    for key, _ in state_dict.items():
        for old_name, new_name in module_name_rewrite_list:
            if old_name in key:
                key_rewrite_list += [(key, key.replace(old_name, new_name))]

    for old_key, new_key in key_rewrite_list:
        if torch.distributed.get_rank() == 0:
            logger.info("replace {} with {}".format(old_key, new_key))
        state_dict[new_key] = state_dict[old_key]
        state_dict.pop(old_key)
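Both hooks above share one rename pattern: scan the keys, collect (old, new) pairs from a substring table, then move the tensors under the new names. A minimal, self-contained sketch of that pattern on a toy state dict (keys and tensor shapes are made up for illustration; the real hooks are registered on the module via PyTorch's load-state-dict pre-hook mechanism rather than called directly, since they also gate logging on torch.distributed.get_rank()):

import torch

# Toy legacy-style state dict; keys and shapes are illustrative only.
toy_state_dict = {
    "decoder.layers.0.input_norm.weight": torch.ones(8),
    "decoder.layers.0.self_attention.query_key_value.weight": torch.zeros(24, 8),
}
rename_table = [
    ("input_norm", "input_layernorm"),
    ("self_attention.query_key_value", "self_attention.linear_qkv"),
]

key_rewrite_list = []
for key in list(toy_state_dict.keys()):
    for old_name, new_name in rename_table:
        if old_name in key:
            key_rewrite_list.append((key, key.replace(old_name, new_name)))

for old_key, new_key in key_rewrite_list:
    toy_state_dict[new_key] = toy_state_dict.pop(old_key)

# toy_state_dict now uses the megatron/core naming:
#   decoder.layers.0.input_layernorm.weight
#   decoder.layers.0.self_attention.linear_qkv.weight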
PAI-Megatron-LM-240718/megatron/core/inference/common_inference_params.py
0 → 100644
from dataclasses import dataclass


@dataclass
class CommonInferenceParams:
    """Inference parameters sent along with the prompts.

    For an explanation of these parameters refer to this blog post:
    https://ivibudh.medium.com/a-guide-to-controlling-llm-model-output-exploring-top-k-top-p-and-temperature-parameters-ed6a31313910
    """

    temperature: float = 1.0
    top_k: int = 0
    top_p: float = 0.0
    return_log_probs: bool = False
    num_tokens_to_generate: int = 30

    def add_attributes(self, attribute_value_pair: dict):
        """Utility to add more attributes to inference params.

        Use this method to pass in a custom dictionary that adds more inference parameter
        attributes to the instance you created. Use as follows:

            c = CommonInferenceParams()
            c.add_attributes({'min_length': 4, 'eod_id': 153})

        Args:
            attribute_value_pair (dict): A dictionary containing attributes as the key names
                and their values as the values.
        """
        for key, value in attribute_value_pair.items():
            setattr(self, key, value)
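A short usage example of the dataclass above; the extra `min_length` and `eod_id` attributes are the hypothetical ones from the docstring, not fields the rest of the code necessarily expects.

params = CommonInferenceParams(temperature=0.8, top_p=0.9, num_tokens_to_generate=64)
params.add_attributes({"min_length": 4, "eod_id": 153})

assert params.top_k == 0        # untouched default
assert params.min_length == 4   # attribute added dynamically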
PAI-Megatron-LM-240718/megatron/core/inference/communication_utils.py
0 → 100644
import torch

from megatron.core import parallel_state


def _is_cuda(tensor):
    """Check if a tensor is not None and is on a CUDA device."""
    assert tensor is not None
    assert tensor.is_cuda


def broadcast_from_last_pipeline_stage(size, dtype, tensor=None):
    """Broadcast a tensor from the last pipeline stage to all ranks."""
    if parallel_state.is_pipeline_last_stage():
        _is_cuda(tensor)
        assert tensor.is_contiguous()
    else:
        tensor = torch.empty(size, dtype=dtype, device=torch.cuda.current_device())
    # Get the group and corresponding source rank.
    src = parallel_state.get_pipeline_model_parallel_last_rank()
    group = parallel_state.get_pipeline_model_parallel_group()
    torch.distributed.broadcast(tensor, src, group)
    return tensor


def recv_from_prev_pipeline_rank_(recv_buffer=None):
    """Receive from the previous pipeline stage and update the
    input buffer in place."""
    recv_prev_op = torch.distributed.P2POp(
        torch.distributed.irecv,
        recv_buffer,
        parallel_state.get_pipeline_model_parallel_prev_rank(),
    )
    reqs = torch.distributed.batch_isend_irecv([recv_prev_op])
    for req in reqs:
        req.wait()
    # To protect against race condition when using batch_isend_irecv().
    torch.cuda.synchronize()


def send_to_next_pipeline_rank(tensor=None):
    """Send output to the next pipeline stage."""
    send_next_op = torch.distributed.P2POp(
        torch.distributed.isend,
        tensor,
        parallel_state.get_pipeline_model_parallel_next_rank(),
    )
    reqs = torch.distributed.batch_isend_irecv([send_next_op])
    for req in reqs:
        req.wait()
    # To protect against race condition when using batch_isend_irecv().
    torch.cuda.synchronize()
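A hedged sketch of how these helpers are commonly combined in a pipeline-parallel decoding step: intermediate activations flow between adjacent stages with the recv/send pair, and the last stage's result is broadcast so every rank sees it. This assumes torch.distributed is initialized (e.g. under torchrun) and megatron.core's model-parallel state has been set up; the function name, tensor shapes, and dtypes are illustrative only.

import torch
from megatron.core import parallel_state


def pipeline_step_sketch(hidden: torch.Tensor, sampled: torch.Tensor = None):
    """Illustrative only: move activations between stages, then share the final tokens."""
    if not parallel_state.is_pipeline_first_stage():
        recv_from_prev_pipeline_rank_(recv_buffer=hidden)  # fill the buffer in place

    # ... run this stage's layers on `hidden` here ...

    if not parallel_state.is_pipeline_last_stage():
        send_to_next_pipeline_rank(hidden)

    # All ranks receive the last stage's result (e.g. sampled token ids);
    # `sampled` is expected to be non-None only on the last stage.
    return broadcast_from_last_pipeline_stage(
        size=list(hidden.shape[:2]), dtype=torch.int64, tensor=sampled
    )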
PAI-Megatron-LM-240718/megatron/core/inference/engines/__init__.py
0 → 100644
PAI-Megatron-LM-240718/megatron/core/inference/engines/abstract_engine.py
0 → 100644
from abc import ABC, abstractmethod
from typing import List


class AbstractEngine(ABC):
    @staticmethod
    @abstractmethod
    def generate(self) -> dict:
        """The abstract backend's generate function.

        To define a new backend, implement this and return the outputs as a dictionary.

        Returns:
            dict: The output dictionary containing keys for `input_prompt`, `generated_text`, `generated_tokens`.
        """
        pass
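A minimal sketch of a backend conforming to this interface. `EchoEngine` and its behavior are hypothetical, purely for illustration; a real engine would run a model to produce the generated text and tokens.

class EchoEngine(AbstractEngine):
    """Hypothetical toy backend that simply echoes its prompts back."""

    def __init__(self, prompts: List[str]):
        self.prompts = prompts

    def generate(self) -> dict:
        # No model involved: return the prompts as the "generated" text with empty token lists.
        return {
            "input_prompt": self.prompts,
            "generated_text": self.prompts,
            "generated_tokens": [[] for _ in self.prompts],
        }


engine = EchoEngine(["hello", "world"])
outputs = engine.generate()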